diff --git a/port_forwarder.go b/port_forwarder.go index e6eedf16c..cc7b6ea1d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -5,77 +5,116 @@ import ( "fmt" "io" "net" - "strconv" ) -// A PortForwader can forward ports from a remote liveshare host to localhost +// A PortForwarder forwards TCP traffic between a port on a remote +// LiveShare host and a local port. type PortForwarder struct { client *Client server *Server port int - errCh chan error } -// NewPortForwarder creates a new PortForwader with a given client, server and port +// NewPortForwarder creates a new PortForwarder that connects a given client, server and port. func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { return &PortForwarder{ client: client, server: server, port: port, - errCh: make(chan error), } } -// Start is a method to start forwarding the server to a localhost port -func (l *PortForwarder) Start(ctx context.Context) error { - ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) +// Forward enables port forwarding. It accepts and handles TCP +// connections until it encounters the first error, which may include +// context cancellation. Its result is non-nil. +func (l *PortForwarder) Forward(ctx context.Context) (err error) { + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) if err != nil { - return fmt.Errorf("error listening on tcp port: %v", err) + return fmt.Errorf("error listening on TCP port: %v", err) } + defer safeClose(listen, &err) + errc := make(chan error, 1) + sendError := func(err error) { + // Use non-blocking send, to avoid goroutines getting + // stuck in case of concurrent or sequential errors. + select { + case errc <- err: + default: + } + } go func() { for { - conn, err := ln.Accept() + conn, err := listen.Accept() if err != nil { - l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err) + sendError(err) + return } - go l.handleConnection(ctx, conn) + go func() { + if err := l.handleConnection(ctx, conn); err != nil { + sendError(err) + } + }() } }() + return awaitError(ctx, errc) +} + +// ForwardWithConn handles port forwarding for a single connection. +func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + // Create buffered channel so that send doesn't get stuck after context cancellation. + errc := make(chan error, 1) + go func() { + if err := l.handleConnection(ctx, conn); err != nil { + errc <- err + } + }() + return awaitError(ctx, errc) +} + +func awaitError(ctx context.Context, errc <-chan error) error { select { - case err := <-l.errCh: + case err := <-errc: return err case <-ctx.Done(): - return ln.Close() + return ctx.Err() // canceled } +} +// handleConnection handles forwarding for a single accepted connection, then closes it. +func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { + defer safeClose(conn, &err) + + channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + if err != nil { + return fmt.Errorf("error opening streaming channel for new connection: %v", err) + } + defer safeClose(channel, &err) + + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) + + // await result + for i := 0; i < 2; i++ { + if err := <-errs; err != nil && err != io.EOF { + return fmt.Errorf("tunnel connection: %v", err) + } + } return nil } -func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { - go l.handleConnection(ctx, conn) - return <-l.errCh -} - -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) { - channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) - if err != nil { - l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) - return +// safeClose reports the error (to *err) from closing the stream only +// if no other error was previously reported. +func safeClose(closer io.Closer, err *error) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr } - - copyConn := func(writer io.Writer, reader io.Reader) { - if _, err := io.Copy(writer, reader); err != nil { - channel.Close() - conn.Close() - if err != io.EOF { - l.errCh <- fmt.Errorf("tunnel connection: %v", err) - } - } - } - - go copyConn(conn, channel) - go copyConn(channel, conn) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 33a33b39b..3ae846937 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -63,10 +63,7 @@ func TestPortForwarderStart(t *testing.T) { if err := server.StartSharing(ctx, "http", 8000); err != nil { done <- fmt.Errorf("start sharing: %v", err) } - if err := pf.Start(ctx); err != nil { - done <- err - } - done <- nil + done <- pf.Forward(ctx) }() go func() {