From 981b2545bc91e6f190c8ac8b8152a7c7cb0695a5 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 17:04:07 -0400 Subject: [PATCH 1/4] sketch of changes for https://github.com/github/go-liveshare/pull/13 --- cmd/ghcs/logs.go | 10 +++++++--- cmd/ghcs/ports.go | 13 ++++++++++--- cmd/ghcs/ssh.go | 18 ++++++++++-------- internal/codespaces/ssh.go | 20 -------------------- internal/codespaces/states.go | 9 ++++++--- 5 files changed, 33 insertions(+), 37 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 49acb3449..590596603 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "os" "github.com/github/ghcs/api" @@ -60,10 +61,13 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("connecting to Live Share: %v", err) } - localSSHPort, err := codespaces.UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := liveshare.Listen(0) // zero => arbitrary if err != nil { return err } + defer listen.Close() + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { @@ -77,7 +81,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { dst := fmt.Sprintf("%s@localhost", sshUser) cmd := codespaces.NewRemoteCommand( - ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) // Error channels are buffered so that neither sending goroutine gets stuck. @@ -85,7 +89,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 800803269..3e403294a 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -277,11 +277,18 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro defer cancel() for _, pair := range portPairs { pair := pair - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) - name := fmt.Sprintf("share-%d", pair.remote) + go func() { + listen, err := liveshare.Listen(pair.local) + if err != nil { + errc <- err + return + } + defer listen.Close() + log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil + errc <- fwd.ForwardToLocalPort(ctx, listen) // error always non-nil }() } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 183019504..2637dab99 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "net" "os" "strings" @@ -81,14 +82,15 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } log.Print("\n") - usingCustomPort := true - if localSSHServerPort == 0 { - usingCustomPort = false // suppress log of command line in Shell - localSSHServerPort, err = codespaces.UnusedPort() - if err != nil { - return err - } + usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell + + // Ensure local port is listening before client (Shell) connects. + listen, err := liveshare.Listen(localSSHServerPort) + if err != nil { + return err } + defer listen.Close() + localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := sshProfile if connectDestination == "" { @@ -98,7 +100,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is always non-nil }() shellClosed := make(chan error) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 1ef2b819f..14dbfbb88 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "os" "os/exec" "strconv" @@ -13,25 +12,6 @@ import ( "github.com/github/go-liveshare" ) -// UnusedPort returns the number of a local TCP port that is currently -// unbound, or an error if none was available. -// -// Use of this function carries an inherent risk of a time-of-check to -// time-of-use race against other processes. -func UnusedPort() (int, error) { - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - return 0, fmt.Errorf("internal error while choosing port: %v", err) - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return 0, fmt.Errorf("choosing available port: %v", err) - } - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil -} - // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 271674e5f..99a713ba8 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "strings" "time" @@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to Live Share: %v", err) } - localSSHPort, err := UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := liveshare.Listen(0) if err != nil { return err } + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { @@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil }() t := time.NewTicker(1 * time.Second) @@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connection failed: %v", err) case <-t.C: - states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser) + states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { return fmt.Errorf("get post create output: %v", err) } From 43198b24aa6c5342dba92cbe514ab30f9dea05ac Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 12:50:11 -0400 Subject: [PATCH 2/4] use errgroup --- cmd/ghcs/logs.go | 28 +++++++--------------------- cmd/ghcs/ports.go | 18 +++++++----------- cmd/ghcs/ssh.go | 27 ++++++++++----------------- 3 files changed, 24 insertions(+), 49 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index d3b2c063f..5e7e8c0a5 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -11,6 +11,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newLogsCmd() *cobra.Command { @@ -84,27 +85,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - // Error channels are buffered so that neither sending goroutine gets stuck. - - tunnelClosed := make(chan error, 1) - go func() { + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil - }() - - cmdDone := make(chan error, 1) - go func() { - cmdDone <- cmd.Run() - }() - - select { - case err := <-tunnelClosed: + err := fwd.ForwardToListener(ctx, listen) // error is non-nil return fmt.Errorf("connection closed: %v", err) - - case err := <-cmdDone: - if err != nil { - return fmt.Errorf("error retrieving logs: %v", err) - } - return nil // success - } + }) + group.Go(cmd.Run) + return group.Wait() } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 6c582c504..958b25996 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -16,6 +16,7 @@ import ( "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) // portOptions represents the options accepted by the ports command. @@ -272,27 +273,22 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. - errc := make(chan error, len(portPairs)) - ctx, cancel := context.WithCancel(ctx) - defer cancel() + group, ctx := errgroup.WithContext(ctx) for _, pair := range portPairs { pair := pair - - go func() { + group.Go(func() error { listen, err := liveshare.ListenTCP(pair.local) if err != nil { - errc <- err - return + return nil } defer listen.Close() log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToListener(ctx, listen) // error always non-nil - }() + return fwd.ForwardToListener(ctx, listen) // error always non-nil + }) } - - return <-errc // first error + return group.Wait() // first error } type portPair struct { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index d7c0847e7..55a406c94 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -13,6 +13,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newSSHCmd() *cobra.Command { @@ -97,28 +98,20 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - tunnelClosed := make(chan error) - go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is always non-nil - }() - - shellClosed := make(chan error) - go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) - }() - log.Println("Ready...") - select { - case err := <-tunnelClosed: + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + err := fwd.ForwardToListener(ctx, listen) // always non-nil return fmt.Errorf("tunnel closed: %v", err) - - case err := <-shellClosed: - if err != nil { + }) + group.Go(func() error { + if err := codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort); err != nil { return fmt.Errorf("shell closed: %v", err) } return nil // success - } + }) + return group.Wait() } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { From 2c660fa2e5a47c499f74aeb7dc522349a5753d3a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 12:55:40 -0400 Subject: [PATCH 3/4] avoid ListenTCP helper --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 3 ++- cmd/ghcs/ssh.go | 2 +- internal/codespaces/states.go | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 5e7e8c0a5..f069a58ab 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -63,7 +63,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { } // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := liveshare.ListenTCP(0) // zero => arbitrary + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 958b25996..fb76022d7 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "strconv" "strings" @@ -277,7 +278,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro for _, pair := range portPairs { pair := pair group.Go(func() error { - listen, err := liveshare.ListenTCP(pair.local) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) if err != nil { return nil } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 55a406c94..6e2724e73 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -86,7 +86,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell // Ensure local port is listening before client (Shell) connects. - listen, err := liveshare.ListenTCP(localSSHServerPort) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort)) if err != nil { return err } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 46d4f5ed5..492ce3964 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -48,7 +48,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := liveshare.ListenTCP(0) + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } From 9e81dc7fdef457f09a18bd81a231b34a3e8f7d03 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 12:56:47 -0400 Subject: [PATCH 4/4] fix missing error return --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index fb76022d7..4258991b6 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -280,7 +280,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro group.Go(func() error { listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) if err != nil { - return nil + return err } defer listen.Close() log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)