diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 8a83b252a..49369d1c0 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "os" "github.com/github/ghcs/api" @@ -10,6 +11,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newLogsCmd() *cobra.Command { @@ -71,10 +73,13 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow return fmt.Errorf("connecting to Live Share: %v", err) } - localSSHPort, err := codespaces.UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } + defer listen.Close() + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { @@ -88,30 +93,15 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow dst := fmt.Sprintf("%s@localhost", sshUser) cmd := codespaces.NewRemoteCommand( - ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - // Error channels are buffered so that neither sending goroutine gets stuck. - - tunnelClosed := make(chan error, 1) - go func() { + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil - }() - - cmdDone := make(chan error, 1) - go func() { - cmdDone <- cmd.Run() - }() - - select { - case err := <-tunnelClosed: + err := fwd.ForwardToListener(ctx, listen) // error is non-nil return fmt.Errorf("connection closed: %v", err) - - case err := <-cmdDone: - if err != nil { - return fmt.Errorf("error retrieving logs: %v", err) - } - return nil // success - } + }) + group.Go(cmd.Run) + return group.Wait() } diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 2f1515ac0..3e861dacb 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -19,8 +19,9 @@ func main() { var version = "DEV" var rootCmd = &cobra.Command{ - Use: "ghcs", - SilenceUsage: true, // don't print usage message after each error (see #80) + Use: "ghcs", + SilenceUsage: true, // don't print usage message after each error (see #80) + SilenceErrors: false, // print errors automatically so that main need not Long: `Unofficial CLI tool to manage GitHub Codespaces. Running commands requires the GITHUB_TOKEN environment variable to be set to a @@ -43,5 +44,4 @@ func explainError(w io.Writer, err error) { fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return } - fmt.Fprintf(w, "%v\n", err) } diff --git a/cmd/ghcs/output/logger.go b/cmd/ghcs/output/logger.go index 32d05acc8..a2aa68ba1 100644 --- a/cmd/ghcs/output/logger.go +++ b/cmd/ghcs/output/logger.go @@ -5,6 +5,8 @@ import ( "io" ) +// NewLogger returns a Logger that will write to the given stdout/stderr writers. +// Disable the Logger to prevent it from writing to stdout in a TTY environment. func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { return &Logger{ out: stdout, @@ -13,12 +15,16 @@ func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { } } +// Logger writes to the given stdout/stderr writers. +// If not enabled, Print functions will noop but Error functions will continue +// to write to the stderr writer. type Logger struct { out io.Writer errout io.Writer enabled bool } +// Print writes the arguments to the stdout writer. func (l *Logger) Print(v ...interface{}) (int, error) { if !l.enabled { return 0, nil @@ -26,6 +32,7 @@ func (l *Logger) Print(v ...interface{}) (int, error) { return fmt.Fprint(l.out, v...) } +// Println writes the arguments to the stdout writer with a newline at the end. func (l *Logger) Println(v ...interface{}) (int, error) { if !l.enabled { return 0, nil @@ -33,6 +40,7 @@ func (l *Logger) Println(v ...interface{}) (int, error) { return fmt.Fprintln(l.out, v...) } +// Printf writes the formatted arguments to the stdout writer. func (l *Logger) Printf(f string, v ...interface{}) (int, error) { if !l.enabled { return 0, nil @@ -40,6 +48,12 @@ func (l *Logger) Printf(f string, v ...interface{}) (int, error) { return fmt.Fprintf(l.out, f, v...) } +// Errorf writes the formatted arguments to the stderr writer. func (l *Logger) Errorf(f string, v ...interface{}) (int, error) { return fmt.Fprintf(l.errout, f, v...) } + +// Errorln writes the arguments to the stderr writer with a newline at the end. +func (l *Logger) Errorln(v ...interface{}) (int, error) { + return fmt.Fprintln(l.errout, v...) +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index dd8f70131..4258991b6 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "strconv" "strings" @@ -16,6 +17,7 @@ import ( "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) // portOptions represents the options accepted by the ports command. @@ -272,20 +274,22 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. - errc := make(chan error, len(portPairs)) - ctx, cancel := context.WithCancel(ctx) - defer cancel() + group, ctx := errgroup.WithContext(ctx) for _, pair := range portPairs { pair := pair - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) - name := fmt.Sprintf("share-%d", pair.remote) - go func() { + group.Go(func() error { + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) + if err != nil { + return err + } + defer listen.Close() + log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil - }() + return fwd.ForwardToListener(ctx, listen) // error always non-nil + }) } - - return <-errc // first error + return group.Wait() // first error } type portPair struct { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 183019504..6e2724e73 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "net" "os" "strings" @@ -12,6 +13,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newSSHCmd() *cobra.Command { @@ -81,42 +83,35 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } log.Print("\n") - usingCustomPort := true - if localSSHServerPort == 0 { - usingCustomPort = false // suppress log of command line in Shell - localSSHServerPort, err = codespaces.UnusedPort() - if err != nil { - return err - } + usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell + + // Ensure local port is listening before client (Shell) connects. + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort)) + if err != nil { + return err } + defer listen.Close() + localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := sshProfile if connectDestination == "" { connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - tunnelClosed := make(chan error) - go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil - }() - - shellClosed := make(chan error) - go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) - }() - log.Println("Ready...") - select { - case err := <-tunnelClosed: + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + err := fwd.ForwardToListener(ctx, listen) // always non-nil return fmt.Errorf("tunnel closed: %v", err) - - case err := <-shellClosed: - if err != nil { + }) + group.Go(func() error { + if err := codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort); err != nil { return fmt.Errorf("shell closed: %v", err) } return nil // success - } + }) + return group.Wait() } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 1ef2b819f..14dbfbb88 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "os" "os/exec" "strconv" @@ -13,25 +12,6 @@ import ( "github.com/github/go-liveshare" ) -// UnusedPort returns the number of a local TCP port that is currently -// unbound, or an error if none was available. -// -// Use of this function carries an inherent risk of a time-of-check to -// time-of-use race against other processes. -func UnusedPort() (int, error) { - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - return 0, fmt.Errorf("internal error while choosing port: %v", err) - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return 0, fmt.Errorf("choosing available port: %v", err) - } - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil -} - // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 271674e5f..492ce3964 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "strings" "time" @@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to Live Share: %v", err) } - localSSHPort, err := UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { @@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil }() t := time.NewTicker(1 * time.Second) @@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connection failed: %v", err) case <-t.C: - states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser) + states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { return fmt.Errorf("get post create output: %v", err) }