diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 49acb3449..f069a58ab 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "os" "github.com/github/ghcs/api" @@ -10,6 +11,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newLogsCmd() *cobra.Command { @@ -60,10 +62,13 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("connecting to Live Share: %v", err) } - localSSHPort, err := codespaces.UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } + defer listen.Close() + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { @@ -77,30 +82,15 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { dst := fmt.Sprintf("%s@localhost", sshUser) cmd := codespaces.NewRemoteCommand( - ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - // Error channels are buffered so that neither sending goroutine gets stuck. - - tunnelClosed := make(chan error, 1) - go func() { + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil - }() - - cmdDone := make(chan error, 1) - go func() { - cmdDone <- cmd.Run() - }() - - select { - case err := <-tunnelClosed: + err := fwd.ForwardToListener(ctx, listen) // error is non-nil return fmt.Errorf("connection closed: %v", err) - - case err := <-cmdDone: - if err != nil { - return fmt.Errorf("error retrieving logs: %v", err) - } - return nil // success - } + }) + group.Go(cmd.Run) + return group.Wait() } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index dd8f70131..4258991b6 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "strconv" "strings" @@ -16,6 +17,7 @@ import ( "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) // portOptions represents the options accepted by the ports command. @@ -272,20 +274,22 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. - errc := make(chan error, len(portPairs)) - ctx, cancel := context.WithCancel(ctx) - defer cancel() + group, ctx := errgroup.WithContext(ctx) for _, pair := range portPairs { pair := pair - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) - name := fmt.Sprintf("share-%d", pair.remote) - go func() { + group.Go(func() error { + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) + if err != nil { + return err + } + defer listen.Close() + log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil - }() + return fwd.ForwardToListener(ctx, listen) // error always non-nil + }) } - - return <-errc // first error + return group.Wait() // first error } type portPair struct { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 183019504..6e2724e73 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "net" "os" "strings" @@ -12,6 +13,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newSSHCmd() *cobra.Command { @@ -81,42 +83,35 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } log.Print("\n") - usingCustomPort := true - if localSSHServerPort == 0 { - usingCustomPort = false // suppress log of command line in Shell - localSSHServerPort, err = codespaces.UnusedPort() - if err != nil { - return err - } + usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell + + // Ensure local port is listening before client (Shell) connects. + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort)) + if err != nil { + return err } + defer listen.Close() + localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := sshProfile if connectDestination == "" { connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - tunnelClosed := make(chan error) - go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil - }() - - shellClosed := make(chan error) - go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) - }() - log.Println("Ready...") - select { - case err := <-tunnelClosed: + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + err := fwd.ForwardToListener(ctx, listen) // always non-nil return fmt.Errorf("tunnel closed: %v", err) - - case err := <-shellClosed: - if err != nil { + }) + group.Go(func() error { + if err := codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort); err != nil { return fmt.Errorf("shell closed: %v", err) } return nil // success - } + }) + return group.Wait() } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 1ef2b819f..14dbfbb88 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "os" "os/exec" "strconv" @@ -13,25 +12,6 @@ import ( "github.com/github/go-liveshare" ) -// UnusedPort returns the number of a local TCP port that is currently -// unbound, or an error if none was available. -// -// Use of this function carries an inherent risk of a time-of-check to -// time-of-use race against other processes. -func UnusedPort() (int, error) { - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - return 0, fmt.Errorf("internal error while choosing port: %v", err) - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return 0, fmt.Errorf("choosing available port: %v", err) - } - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil -} - // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 271674e5f..492ce3964 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "strings" "time" @@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to Live Share: %v", err) } - localSSHPort, err := UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { @@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil }() t := time.NewTicker(1 * time.Second) @@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connection failed: %v", err) case <-t.C: - states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser) + states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { return fmt.Errorf("get post create output: %v", err) }