move port choice, and PortForwarder.Start call, into clients
This commit is contained in:
parent
509e037a5e
commit
c0aae52289
4 changed files with 93 additions and 70 deletions
|
|
@ -59,7 +59,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
|
||||||
return fmt.Errorf("connecting to liveshare: %v", err)
|
return fmt.Errorf("connecting to liveshare: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelPort, connClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", 0)
|
port, err := codespaces.UnusedPort()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("make ssh tunnel: %v", err)
|
return fmt.Errorf("make ssh tunnel: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -71,23 +76,29 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
|
||||||
|
|
||||||
dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace))
|
dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace))
|
||||||
cmd := codespaces.NewRemoteCommand(
|
cmd := codespaces.NewRemoteCommand(
|
||||||
ctx, tunnelPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
ctx, port, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Channel is buffered to avoid a goroutine leak when connClosed occurs before done.
|
// Error channels are buffered so that neither sending goroutine gets stuck.
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() { done <- cmd.Run() }()
|
tunnelClosed := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
tunnelClosed <- tunnel.Start(ctx) // error is non-nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmdDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
cmdDone <- cmd.Run()
|
||||||
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-connClosed:
|
case err := <-tunnelClosed:
|
||||||
if err != nil {
|
return fmt.Errorf("connection closed: %v", err)
|
||||||
return fmt.Errorf("connection closed: %v", err)
|
|
||||||
}
|
case err := <-cmdDone:
|
||||||
case err := <-done:
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error retrieving logs: %v", err)
|
return fmt.Errorf("error retrieving logs: %v", err)
|
||||||
}
|
}
|
||||||
|
return nil // success
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,17 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in
|
||||||
}
|
}
|
||||||
log.Print("\n")
|
log.Print("\n")
|
||||||
|
|
||||||
tunnelPort, tunnelClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", sshServerPort)
|
usingCustomPort := true
|
||||||
|
if sshServerPort == 0 {
|
||||||
|
usingCustomPort = false // suppress log of command line in Shell
|
||||||
|
port, err := codespaces.UnusedPort()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sshServerPort = port
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", sshServerPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("make ssh tunnel: %v", err)
|
return fmt.Errorf("make ssh tunnel: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -90,26 +100,27 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in
|
||||||
connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace))
|
connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace))
|
||||||
}
|
}
|
||||||
|
|
||||||
usingCustomPort := tunnelPort == sshServerPort
|
tunnelClosed := make(chan error)
|
||||||
|
go func() {
|
||||||
|
tunnelClosed <- tunnel.Start(ctx) // error is always non-nil
|
||||||
|
}()
|
||||||
|
|
||||||
shellClosed := make(chan error)
|
shellClosed := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
shellClosed <- codespaces.Shell(ctx, log, tunnelPort, connectDestination, usingCustomPort)
|
shellClosed <- codespaces.Shell(ctx, log, sshServerPort, connectDestination, usingCustomPort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Println("Ready...")
|
log.Println("Ready...")
|
||||||
select {
|
select {
|
||||||
case err := <-tunnelClosed:
|
case err := <-tunnelClosed:
|
||||||
if err != nil {
|
return fmt.Errorf("tunnel closed: %v", err)
|
||||||
return fmt.Errorf("tunnel closed: %v", err)
|
|
||||||
}
|
|
||||||
case err := <-shellClosed:
|
case err := <-shellClosed:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("shell closed: %v", err)
|
return fmt.Errorf("shell closed: %v", err)
|
||||||
}
|
}
|
||||||
|
return nil // success
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {
|
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package codespaces
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
@ -12,57 +12,47 @@ import (
|
||||||
"github.com/github/go-liveshare"
|
"github.com/github/go-liveshare"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StartPortForwarding starts LiveShare port forwarding of traffic
|
// UnusedPort returns the number of a local TCP port that is currently
|
||||||
// between the LiveShare client and the specified local port, or, if
|
// unbound, or an error if none was available.
|
||||||
// zero, a port chosen at random; the effective port number is
|
//
|
||||||
// returned. Forwarding continues in the background until an error is
|
// Use of this function carries an inherent risk of a time-of-check to
|
||||||
// encountered (including cancellation of the context). Therefore
|
// time-of-use race against other processes.
|
||||||
// clients must cancel the context.
|
func UnusedPort() (int, error) {
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("internal error while choosing port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := net.ListenTCP("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("choosing available port: %v", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
return l.Addr().(*net.TCPAddr).Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPortForwarder returns a new port forwarder for traffic between
|
||||||
|
// the Live Share client and the specified local port (which must be
|
||||||
|
// available).
|
||||||
//
|
//
|
||||||
// The session name is used (along with the port) to generate
|
// The session name is used (along with the port) to generate
|
||||||
// names for streams, and may appear in error messages.
|
// names for streams, and may appear in error messages.
|
||||||
//
|
func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localPort int) (*liveshare.PortForwarder, error) {
|
||||||
// TODO(adonovan): simplify API concurrency from API. Either:
|
if localPort == 0 {
|
||||||
// 1) return a stop function so that clients don't forget to stop forwarding.
|
return nil, fmt.Errorf("a local port must be provided")
|
||||||
// 2) avoid creating a goroutine and returning a channel. Use approach of
|
|
||||||
// http.ListenAndServe, which runs until it encounters an error
|
|
||||||
// (incl. cancellation). But this means we can't return the port.
|
|
||||||
// Can we make the client responsible for supplying it?
|
|
||||||
// 3) return a PortForwarding object that encapsulates the port,
|
|
||||||
// and has NewRemoteCommand as a method. It will need a Stop method,
|
|
||||||
// and an Error method for querying whether the session has failed
|
|
||||||
// asynchronously.
|
|
||||||
func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, sessionName string, localPort int) (int, <-chan error, error) {
|
|
||||||
server, err := liveshare.NewServer(lsclient)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("new liveshare server: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if localPort == 0 {
|
server, err := liveshare.NewServer(client)
|
||||||
localPort = rand.Intn(9999-2000) + 2000
|
if err != nil {
|
||||||
// TODO(adonovan): retry if port is taken?
|
return nil, fmt.Errorf("new liveshare server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(josebalius): This port won't always be 2222
|
// TODO(josebalius): This port won't always be 2222
|
||||||
if err := server.StartSharing(ctx, sessionName, 2222); err != nil {
|
if err := server.StartSharing(ctx, sessionName, 2222); err != nil {
|
||||||
return 0, nil, fmt.Errorf("sharing sshd port: %v", err)
|
return nil, fmt.Errorf("sharing sshd port: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelClosed := make(chan error)
|
return liveshare.NewPortForwarder(client, server, localPort), nil
|
||||||
go func() {
|
|
||||||
// TODO(adonovan): simplify liveshare API to combine NewPortForwarder and Start
|
|
||||||
// methods into a single ForwardPort call, like http.ListenAndServe.
|
|
||||||
// (Start is a misnomer: it runs the complete session.)
|
|
||||||
// Also document that it never returns a nil error.
|
|
||||||
portForwarder := liveshare.NewPortForwarder(lsclient, server, localPort)
|
|
||||||
if err := portForwarder.Start(ctx); err != nil {
|
|
||||||
tunnelClosed <- fmt.Errorf("forwarding port: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tunnelClosed <- nil
|
|
||||||
}()
|
|
||||||
|
|
||||||
return localPort, tunnelClosed, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shell runs an interactive secure shell over an existing
|
// Shell runs an interactive secure shell over an existing
|
||||||
|
|
@ -78,8 +68,8 @@ func Shell(ctx context.Context, log logger, port int, destination string, usingC
|
||||||
return cmd.Run()
|
return cmd.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRemoteCommand returns a partially populated exec.Cmd that will
|
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
|
||||||
// securely run a shell command on the remote machine.
|
// command on the remote machine.
|
||||||
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
|
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
|
||||||
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
|
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
|
||||||
return cmd
|
return cmd
|
||||||
|
|
@ -92,7 +82,6 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm
|
||||||
// TODO(adonovan): eliminate X11 and X11Trust flags where unneeded.
|
// TODO(adonovan): eliminate X11 and X11Trust flags where unneeded.
|
||||||
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
|
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
|
||||||
|
|
||||||
// An empty command enables port forwarding but not execution.
|
|
||||||
if command != "" {
|
if command != "" {
|
||||||
cmdArgs = append(cmdArgs, command)
|
cmdArgs = append(cmdArgs, command)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -45,22 +45,34 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
|
||||||
return fmt.Errorf("connect to liveshare: %v", err)
|
return fmt.Errorf("connect to liveshare: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelPort, connClosed, err := StartPortForwarding(ctx, lsclient, "sshd", 0)
|
port, err := UnusedPort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("make ssh tunnel: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fwd, err := NewPortForwarder(ctx, lsclient, "sshd", port)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating port forwarder: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
|
||||||
|
go func() {
|
||||||
|
tunnelClosed <- fwd.Start(ctx) // error is non-nil
|
||||||
|
}()
|
||||||
|
|
||||||
t := time.NewTicker(1 * time.Second)
|
t := time.NewTicker(1 * time.Second)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil // canceled
|
||||||
case err := <-connClosed:
|
|
||||||
return fmt.Errorf("connection closed: %v", err)
|
case err := <-tunnelClosed:
|
||||||
|
return fmt.Errorf("connection failed: %v", err)
|
||||||
|
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
states, err := getPostCreateOutput(ctx, tunnelPort, codespace)
|
states, err := getPostCreateOutput(ctx, port, codespace)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get post create output: %v", err)
|
return fmt.Errorf("get post create output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue