diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 49acb3449..590596603 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "os" "github.com/github/ghcs/api" @@ -60,10 +61,13 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("connecting to Live Share: %v", err) } - localSSHPort, err := codespaces.UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := liveshare.Listen(0) // zero => arbitrary if err != nil { return err } + defer listen.Close() + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { @@ -77,7 +81,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { dst := fmt.Sprintf("%s@localhost", sshUser) cmd := codespaces.NewRemoteCommand( - ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) // Error channels are buffered so that neither sending goroutine gets stuck. @@ -85,7 +89,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 800803269..3e403294a 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -277,11 +277,18 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro defer cancel() for _, pair := range portPairs { pair := pair - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) - name := fmt.Sprintf("share-%d", pair.remote) + go func() { + listen, err := liveshare.Listen(pair.local) + if err != nil { + errc <- err + return + } + defer listen.Close() + log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil + errc <- fwd.ForwardToLocalPort(ctx, listen) // error always non-nil }() } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 183019504..2637dab99 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "net" "os" "strings" @@ -81,14 +82,15 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } log.Print("\n") - usingCustomPort := true - if localSSHServerPort == 0 { - usingCustomPort = false // suppress log of command line in Shell - localSSHServerPort, err = codespaces.UnusedPort() - if err != nil { - return err - } + usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell + + // Ensure local port is listening before client (Shell) connects. + listen, err := liveshare.Listen(localSSHServerPort) + if err != nil { + return err } + defer listen.Close() + localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := sshProfile if connectDestination == "" { @@ -98,7 +100,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is always non-nil }() shellClosed := make(chan error) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 1ef2b819f..14dbfbb88 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "os" "os/exec" "strconv" @@ -13,25 +12,6 @@ import ( "github.com/github/go-liveshare" ) -// UnusedPort returns the number of a local TCP port that is currently -// unbound, or an error if none was available. -// -// Use of this function carries an inherent risk of a time-of-check to -// time-of-use race against other processes. -func UnusedPort() (int, error) { - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - return 0, fmt.Errorf("internal error while choosing port: %v", err) - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return 0, fmt.Errorf("choosing available port: %v", err) - } - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil -} - // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 271674e5f..99a713ba8 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "strings" "time" @@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to Live Share: %v", err) } - localSSHPort, err := UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := liveshare.Listen(0) if err != nil { return err } + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { @@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil }() t := time.NewTicker(1 * time.Second) @@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connection failed: %v", err) case <-t.C: - states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser) + states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { return fmt.Errorf("get post create output: %v", err) }