cli/port_forwarder.go
2021-09-02 11:06:58 -04:00

117 lines
2.9 KiB
Go

package liveshare
import (
"context"
"fmt"
"io"
"net"
)
// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session.
//
// Construct one with NewPortForwarder; Forward serves the local port, while
// ForwardWithConn relays a single caller-supplied connection.
type PortForwarder struct {
	session *Session // session whose streaming channel each connection is relayed over
	port    int      // local TCP port that Forward listens on
}
// NewPortForwarder returns a PortForwarder that relays TCP traffic between
// the given local port and the supplied Live Share session. No listening or
// I/O happens until Forward or ForwardWithConn is called.
func NewPortForwarder(session *Session, port int) *PortForwarder {
	pf := new(PortForwarder)
	pf.session = session
	pf.port = port
	return pf
}
// Forward enables port forwarding. It listens on the PortForwarder's TCP
// port and handles each accepted connection in its own goroutine until it
// encounters the first error, which may include context cancellation, a
// failed Accept, or a failed connection. Its result is non-nil.
//
// The listener is closed before Forward returns. A close error is reported
// only if no earlier error was.
//
// NOTE(review): the address ":%d" has no host part, so it binds every
// interface and exposes the forwarded port to the network. Confirm whether
// loopback-only ("127.0.0.1:%d") was intended.
func (l *PortForwarder) Forward(ctx context.Context) (err error) {
	listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port))
	if err != nil {
		// %w keeps the underlying net error inspectable via errors.Is/As.
		return fmt.Errorf("error listening on TCP port: %w", err)
	}
	defer safeClose(listen, &err)

	// errc holds at most one error: the first one reported wins.
	errc := make(chan error, 1)
	sendError := func(err error) {
		// Use non-blocking send, to avoid goroutines getting
		// stuck in case of concurrent or sequential errors.
		select {
		case errc <- err:
		default:
		}
	}

	go func() {
		// This loop ends when Accept fails. That includes the listener being
		// closed by the deferred safeClose once Forward returns.
		for {
			conn, err := listen.Accept()
			if err != nil {
				sendError(err)
				return
			}
			go func() {
				if err := l.handleConnection(ctx, conn); err != nil {
					sendError(err)
				}
			}()
		}
	}()

	return awaitError(ctx, errc)
}
// ForwardWithConn handles port forwarding for a single connection and
// closes conn when done. It returns:
//   - nil once the connection has been forwarded to completion,
//   - the forwarding error if one occurs, or
//   - ctx.Err() if ctx is done first.
func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
	// Buffered so the goroutine's send never blocks, even if awaitError has
	// already returned because ctx was canceled.
	errc := make(chan error, 1)
	go func() {
		// Send the result even when it is nil. Otherwise a clean completion
		// would leave awaitError waiting until ctx is canceled.
		errc <- l.handleConnection(ctx, conn)
	}()
	return awaitError(ctx, errc)
}
// awaitError blocks until a value arrives on errc or ctx is done.
// It returns the received value (which may be nil) in the first case and
// ctx.Err() in the second. If both are ready, either may be chosen.
func awaitError(ctx context.Context, errc <-chan error) error {
	var result error
	select {
	case <-ctx.Done():
		result = ctx.Err() // canceled
	case result = <-errc:
	}
	return result
}
// handleConnection relays bytes in both directions between conn and a newly
// opened streaming channel of the session, then closes both.
//
// It returns:
//   - nil when both copy directions finish cleanly,
//   - a wrapped error from the first direction that fails, or
//   - ctx.Err() as soon as ctx is done.
//
// Returning early runs the deferred closes of the channel and conn. These are
// expected to unblock any io.Copy still in progress (NOTE(review): assumes
// the channel's Close interrupts pending reads/writes — confirm).
// Close errors are reported only if no earlier error was.
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
	defer safeClose(conn, &err)

	channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition)
	if err != nil {
		return fmt.Errorf("error opening streaming channel for new connection: %w", err)
	}
	defer safeClose(channel, &err)

	// Room for both results, so neither copier blocks on send after an early return.
	errs := make(chan error, 2)
	copyConn := func(w io.Writer, r io.Reader) {
		_, err := io.Copy(w, r)
		errs <- err
	}
	go copyConn(conn, channel)
	go copyConn(channel, conn)

	// Wait for both directions, but stop promptly if ctx is canceled.
	for pending := 2; pending > 0; pending-- {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-errs:
			// io.Copy reports a clean end of stream as nil; the io.EOF check is defensive.
			if err != nil && err != io.EOF {
				return fmt.Errorf("tunnel connection: %w", err)
			}
		}
	}
	return nil
}
// safeClose always calls closer.Close. It stores the close error in *err
// only when *err is still nil, so an error reported earlier is never
// overwritten by a later close failure. It is meant to be deferred with a
// pointer to the caller's named error result.
func safeClose(closer io.Closer, err *error) {
	if closeErr := closer.Close(); *err == nil {
		*err = closeErr
	}
}