cli/pkg/liveshare/port_forwarder.go
2023-01-19 20:44:38 -06:00

241 lines
6.6 KiB
Go

package liveshare
import (
	"context"
	"fmt"
	"io"
	"net"
	"sync"

	"github.com/opentracing/opentracing-go"
	"golang.org/x/crypto/ssh"
)
// portForwardingSession is the set of session operations that a
// PortForwarder relies on.
type portForwardingSession interface {
	// StartSharing is invoked with the forwarder's name and remote port;
	// the ChannelID it yields is later handed to OpenStreamingChannel.
	StartSharing(ctx context.Context, name string, port int) (ChannelID, error)

	// OpenStreamingChannel opens a stream for the given channel ID.
	OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error)

	// KeepAlive is notified with the traffic type ("input" or "output")
	// by forwarders that were created with keepAlive enabled.
	KeepAlive(trafficType string)
}
// ReadWriteHalfCloser is a read/write stream whose write side can be
// shut down independently (CloseWrite) before the stream as a whole
// is closed (Close).
type ReadWriteHalfCloser interface {
	io.ReadWriteCloser
	CloseWrite() error
}

// combinedReadWriteHalfCloser builds a ReadWriteHalfCloser out of an
// independent reader and writer. Reads go to the reader, writes go to
// the writer, and the writer is closed at most once no matter how
// CloseWrite and Close are interleaved.
type combinedReadWriteHalfCloser struct {
	io.ReadCloser
	io.WriteCloser

	// closeWriteOnce guards the single call to WriteCloser.Close;
	// closeWriteErr caches its result for later callers.
	closeWriteOnce sync.Once
	closeWriteErr  error
}

// NewReadWriteHalfCloser returns a ReadWriteHalfCloser that reads from
// reader and writes to writer.
func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
	return &combinedReadWriteHalfCloser{ReadCloser: reader, WriteCloser: writer}
}

// Close closes the write side (unless CloseWrite already did) and then
// the read side. Both are always attempted; the writer's close error
// takes precedence over the reader's.
func (crwc *combinedReadWriteHalfCloser) Close() error {
	// Going through CloseWrite avoids closing the writer a second time
	// when the caller half-closed the stream earlier. io.Closer leaves
	// the behavior of a repeated Close undefined, and some writers
	// report it as an error.
	werr := crwc.CloseWrite()
	rerr := crwc.ReadCloser.Close()
	if werr != nil {
		return werr
	}
	return rerr
}

// CloseWrite closes the write side only. The underlying writer is
// closed on the first call; every call returns that first result.
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
	crwc.closeWriteOnce.Do(func() {
		crwc.closeWriteErr = crwc.WriteCloser.Close()
	})
	return crwc.closeWriteErr
}
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
// session is used to share the remote port and to open a streaming
// channel for each forwarded connection.
session portForwardingSession
// name describes the purpose of the remote port or service; it is
// passed to the session when the port is shared.
name string
// remotePort is the port on the remote container to forward.
remotePort int
// keepAlive, when true, makes forwarded traffic notify the session
// through KeepAlive.
keepAlive bool
}
// NewPortForwarder builds a PortForwarder that forwards remotePort
// over the given Live Share session. name is a description of the
// remote port or service. When keepAlive is true, forwarded traffic
// is reported to the session so that it stays alive.
func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder {
	fwd := new(PortForwarder)
	fwd.session = session
	fwd.name = name
	fwd.remotePort = remotePort
	fwd.keepAlive = keepAlive
	return fwd
}
// ForwardToListener shares the remote port and then relays every
// connection accepted on listen to it. The listener must already be
// listening: taking a listener instead of a port number avoids racing
// with other processes for the port, and with clients that connect
// before the socket is ready.
//
// Connections are accepted and served concurrently until the first
// error occurs or ctx is done, whichever comes first. The returned
// error is therefore never nil. Closing the listener is left to the
// caller.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
	channelID, err := fwd.shareRemotePort(ctx)
	if err != nil {
		return err
	}

	// Room for exactly one error: only the first failure is reported.
	firstErr := make(chan error, 1)

	// report never blocks, so goroutines that fail after an error has
	// already been recorded simply drop theirs and exit.
	report := func(e error) {
		select {
		case firstErr <- e:
		default:
		}
	}

	// serve relays a single accepted connection.
	serve := func(conn *net.TCPConn) {
		if e := fwd.handleConnection(ctx, channelID, conn); e != nil {
			report(e)
		}
	}

	// Accept loop: ends on the first Accept error.
	go func() {
		for {
			conn, acceptErr := listen.AcceptTCP()
			if acceptErr != nil {
				report(acceptErr)
				return
			}
			go serve(conn)
		}
	}()

	return awaitError(ctx, firstErr)
}
// Forward shares the remote port and relays traffic between it and
// conn. It returns when the relay ends or ctx is done. conn is closed
// here if sharing the port fails, and by handleConnection otherwise.
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
	channelID, shareErr := fwd.shareRemotePort(ctx)
	if shareErr != nil {
		// The sharing failure is the error worth reporting; the
		// result of closing the stream is deliberately dropped.
		_ = conn.Close()
		return shareErr
	}

	// Buffered so the relay goroutine can deliver its result and exit
	// even when nobody is receiving any more (context cancellation).
	result := make(chan error, 1)
	go func() {
		result <- fwd.handleConnection(ctx, channelID, conn)
	}()

	return awaitError(ctx, result)
}
// shareRemotePort asks the session to start sharing the forwarder's
// remote port and returns the resulting channel ID. A failure is
// wrapped with the port number.
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
	id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
	if err == nil {
		return id, nil
	}
	return id, fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
}
// awaitError blocks until either a value arrives on errc or ctx is
// done. It returns the received value in the first case and ctx.Err()
// in the second.
func awaitError(ctx context.Context, errc <-chan error) error {
	done := ctx.Done()
	select {
	case <-done:
		return ctx.Err() // canceled
	case received := <-errc:
		return received
	}
}
// trafficMonitorSession is the part of a session that a trafficMonitor
// needs: a way to report that traffic of a given type was seen.
type trafficMonitorSession interface {
	KeepAlive(trafficType string)
}
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
// it of the traffic type during Read operations.
type trafficMonitor struct {
// reader is the underlying stream that Read delegates to.
reader io.Reader
// session is notified through KeepAlive on every Read.
session trafficMonitorSession
// trafficType is the value passed to KeepAlive.
trafficType string
}
// newTrafficMonitor wraps reader in a trafficMonitor that reports
// trafficType to session whenever it is read from.
func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor {
	return &trafficMonitor{
		reader:      reader,
		session:     session,
		trafficType: trafficType,
	}
}
// Read notifies the session of the monitor's traffic type and then
// reads from the wrapped reader. The notification is sent before the
// read, whether or not the read goes on to return any data.
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
	t.session.KeepAlive(t.trafficType)
	n, err = t.reader.Read(p)
	return n, err
}
// handleConnection relays data in both directions between conn and a
// freshly opened streaming channel for id, and closes both before
// returning. It returns when both directions have finished or when
// ctx is done. Copy errors are not reported; errors from opening the
// channel and from closing either end are.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
	defer span.Finish()

	// Runs after the channel is closed (defers run last-in first-out)
	// and reports conn's close error only if nothing else failed.
	defer safeClose(conn, &err)

	channel, err := fwd.session.OpenStreamingChannel(ctx, id)
	if err != nil {
		return fmt.Errorf("error opening streaming channel for new connection: %w", err)
	}

	// safeClose would be the natural choice here too, but
	// (*ssh.channel).Close appears to return io.EOF spuriously when
	// its peer closed first (github.com/golang/go/issues/38115), so
	// that particular result is filtered out.
	defer func() {
		if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
			err = closeErr
		}
	}()

	// When the forwarder has been configured to keep the session alive,
	// reads in both directions are routed through a traffic monitor
	// that notifies the session.
	fromChannel, fromConn := io.Reader(channel), io.Reader(conn)
	if fwd.keepAlive {
		fromChannel = newTrafficMonitor(fromChannel, fwd.session, "output")
		fromConn = newTrafficMonitor(fromConn, fwd.session, "input")
	}

	// One slot per direction, so neither copier can block on its send
	// even if this function has already returned.
	finished := make(chan error, 2)
	pump := func(dst ReadWriteHalfCloser, src io.Reader) {
		_, copyErr := io.Copy(dst, src)
		finished <- copyErr
		// The half-close result is ignored: the full Close performed
		// on return is where close errors are collected.
		_ = dst.CloseWrite()
	}
	go pump(conn, fromChannel)
	go pump(channel, fromConn)

	// Wait for cancellation or for both directions to finish. The
	// io.Copy results are dropped on purpose: they should not cause
	// (e.g.) ForwardToListener to fail.
	// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
	for remaining := 2; remaining > 0; {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-finished:
			remaining--
		}
	}
	return nil
}
// safeClose always closes closer. The close error is stored in *err
// only when *err is still nil, so an earlier failure is never
// overwritten.
func safeClose(closer io.Closer, err *error) {
	if closeErr := closer.Close(); *err == nil {
		*err = closeErr
	}
}