From 2b95cbc5a6933154a2c4005ce7df431fe7566672 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Thu, 19 Jan 2023 20:22:24 -0600 Subject: [PATCH 1/2] Close port forward writer on reader --- internal/codespaces/codespaces.go | 16 ++++++++++++ internal/codespaces/rpc/invoker.go | 19 +++++++++++--- internal/codespaces/states.go | 4 +-- pkg/cmd/codespace/jupyter.go | 3 ++- pkg/cmd/codespace/logs.go | 4 +-- pkg/cmd/codespace/ports.go | 3 +-- pkg/cmd/codespace/ssh.go | 25 ++---------------- pkg/liveshare/port_forwarder.go | 39 ++++++++++++++++++++++++---- pkg/liveshare/port_forwarder_test.go | 6 ++++- 9 files changed, 78 insertions(+), 41 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 2dc81ba64..8fb096b06 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "time" "github.com/cenkalti/backoff/v4" @@ -79,3 +80,18 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session Logger: sessionLogger, }) } + +// ListenTCP starts a localhost tcp listener and returns the listener and bound port +func ListenTCP(port int) (*net.TCPListener, int, error) { + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return nil, 0, fmt.Errorf("failed to build tcp address: %w", err) + } + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, 0, fmt.Errorf("failed to listen to local port over tcp: %w", err) + } + port = listener.Addr().(*net.TCPAddr).Port + + return listener, port, nil +} diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index 67a88bb2f..d22aad4e1 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -68,11 +68,11 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv // Finds a free port to listen on and creates a new RPC invoker that connects to that port func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) { - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) + listener, err := listenTCP() if err != nil { - return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) + return nil, err } - localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) + localAddress := listener.Addr().String() invoker := &invoker{ session: session, @@ -229,3 +229,16 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS return port, response.User, nil } + +func listenTCP() (*net.TCPListener, error) { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("failed to build tcp address: %w", err) + } + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) + } + + return listener, nil +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 58be127be..9874d1a62 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "log" - "net" "time" "github.com/cli/cli/v2/internal/codespaces/api" @@ -53,11 +52,10 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl }() // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port + listen, localPort, err := ListenTCP(0) if err != nil { return err } - localPort := listen.Addr().(*net.TCPAddr).Port progress.StartProgressIndicatorWithLabel("Fetching SSH Details") invoker, err := rpc.CreateInvoker(ctx, session) diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index d60758494..0e3e0dee0 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -6,6 +6,7 @@ import ( "net" "strings" + "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" @@ -58,7 +59,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { a.StopProgressIndicator() // Pass 0 to pick a random port - listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) + listen, _, err := codespaces.ListenTCP(0) if err != nil { return err } diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index df080ef3b..9a42b9866 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -3,7 +3,6 @@ package codespace import ( "context" "fmt" - "net" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/rpc" @@ -49,12 +48,11 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err defer safeClose(session, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port + listen, localPort, err := codespaces.ListenTCP(0) if err != nil { return err } defer listen.Close() - localPort := listen.Addr().(*net.TCPAddr).Port a.StartProgressIndicatorWithLabel("Fetching SSH Details") invoker, err := rpc.CreateInvoker(ctx, session) diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 47fd4d979..d36fa54a1 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "net" "net/http" "strconv" "strings" @@ -390,7 +389,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st for _, pair := range portPairs { pair := pair group.Go(func() error { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) + listen, _, err := codespaces.ListenTCP(pair.local) if err != nil { return err } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 3a3fdc86a..d8788b45a 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -6,9 +6,7 @@ import ( "context" "errors" "fmt" - "io" "log" - "net" "os" "os/exec" "path" @@ -188,7 +186,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e if opts.stdio { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true) - stdio := newReadWriteCloser(os.Stdin, os.Stdout) + stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout) err := fwd.Forward(ctx, stdio) // always non-nil return fmt.Errorf("tunnel closed: %w", err) } @@ -199,12 +197,11 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e // Ensure local port is listening before client (Shell) connects. // Unless the user specifies a server port, localSSHServerPort is 0 // and thus the client will pick a random port. - listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localSSHServerPort)) + listen, localSSHServerPort, err := codespaces.ListenTCP(localSSHServerPort) if err != nil { return err } defer listen.Close() - localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := opts.profile if connectDestination == "" { @@ -745,21 +742,3 @@ func (fl *fileLogger) Name() string { func (fl *fileLogger) Close() error { return fl.f.Close() } - -type combinedReadWriteCloser struct { - io.ReadCloser - io.WriteCloser -} - -func newReadWriteCloser(reader io.ReadCloser, writer io.WriteCloser) io.ReadWriteCloser { - return &combinedReadWriteCloser{reader, writer} -} - -func (crwc *combinedReadWriteCloser) Close() error { - werr := crwc.WriteCloser.Close() - rerr := crwc.ReadCloser.Close() - if werr != nil { - return werr - } - return rerr -} diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index f042eeaea..9b47633fc 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -16,6 +16,33 @@ type portForwardingSession interface { KeepAlive(string) } +type ReadWriteHalfCloser interface { + io.ReadWriteCloser + CloseWrite() error +} + +type combinedReadWriteHalfCloser struct { + io.ReadCloser + io.WriteCloser +} + +func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser { + return &combinedReadWriteHalfCloser{reader, writer} +} + +func (crwc *combinedReadWriteHalfCloser) Close() error { + werr := crwc.WriteCloser.Close() + rerr := crwc.ReadCloser.Close() + if werr != nil { + return werr + } + return rerr +} + +func (crwc *combinedReadWriteHalfCloser) CloseWrite() error { + return crwc.WriteCloser.Close() +} + // A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote // container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { @@ -48,7 +75,7 @@ func NewPortForwarder(session portForwardingSession, name string, remotePort int // until it encounters the first error, which may include context // cancellation. Its error result is always non-nil. The caller is // responsible for closing the listening port. -func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { +func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err @@ -65,7 +92,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List } go func() { for { - conn, err := listen.Accept() + conn, err := listen.AcceptTCP() if err != nil { sendError(err) return @@ -84,7 +111,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List // Forward forwards traffic between the container's remote port and // the specified read/write stream. On return, the stream is closed. -func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { +func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { conn.Close() @@ -143,7 +170,7 @@ func (t *trafficMonitor) Read(p []byte) (n int, err error) { } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") defer span.Finish() @@ -165,9 +192,11 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, co // bi-directional copy of data. errs := make(chan error, 2) - copyConn := func(w io.Writer, r io.Reader) { + copyConn := func(w ReadWriteHalfCloser, r io.Reader) { _, err := io.Copy(w, r) errs <- err + + w.CloseWrite() } var ( diff --git a/pkg/liveshare/port_forwarder_test.go b/pkg/liveshare/port_forwarder_test.go index b02165849..61acde368 100644 --- a/pkg/liveshare/port_forwarder_test.go +++ b/pkg/liveshare/port_forwarder_test.go @@ -71,6 +71,10 @@ func TestPortForwarderStart(t *testing.T) { t.Fatal(err) } defer listen.Close() + tcpListener, ok := listen.(*net.TCPListener) + if !ok { + t.Fatal("net.Listen did not return a TCPListener") + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -82,7 +86,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error, 2) go func() { - done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen) + done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener) }() go func() { From 21c9e7c6dba02ad97179bb987477ad9abe12dad2 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Thu, 19 Jan 2023 20:44:38 -0600 Subject: [PATCH 2/2] Linter and comment --- pkg/liveshare/port_forwarder.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index 9b47633fc..5f2742209 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -196,7 +196,8 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, co _, err := io.Copy(w, r) errs <- err - w.CloseWrite() + // Ignore errors here, we call the full Close() later and catch that error + _ = w.CloseWrite() } var (