241 lines
6.6 KiB
Go
241 lines
6.6 KiB
Go
package liveshare
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
|
|
"github.com/opentracing/opentracing-go"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type portForwardingSession interface {
|
|
StartSharing(context.Context, string, int) (ChannelID, error)
|
|
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
|
|
KeepAlive(string)
|
|
}
|
|
|
|
type ReadWriteHalfCloser interface {
|
|
io.ReadWriteCloser
|
|
CloseWrite() error
|
|
}
|
|
|
|
type combinedReadWriteHalfCloser struct {
|
|
io.ReadCloser
|
|
io.WriteCloser
|
|
}
|
|
|
|
func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
|
|
return &combinedReadWriteHalfCloser{reader, writer}
|
|
}
|
|
|
|
func (crwc *combinedReadWriteHalfCloser) Close() error {
|
|
werr := crwc.WriteCloser.Close()
|
|
rerr := crwc.ReadCloser.Close()
|
|
if werr != nil {
|
|
return werr
|
|
}
|
|
return rerr
|
|
}
|
|
|
|
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
|
|
return crwc.WriteCloser.Close()
|
|
}
|
|
|
|
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
|
|
// container to a local destination such as a network port or Go reader/writer.
|
|
type PortForwarder struct {
|
|
session portForwardingSession
|
|
name string
|
|
remotePort int
|
|
keepAlive bool
|
|
}
|
|
|
|
// NewPortForwarder returns a new PortForwarder for the specified
|
|
// remote port and Live Share session. The name describes the purpose
|
|
// of the remote port or service. The keepAlive flag indicates whether
|
|
// the session should be kept alive with port forwarding traffic.
|
|
func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder {
|
|
return &PortForwarder{
|
|
session: session,
|
|
name: name,
|
|
remotePort: remotePort,
|
|
keepAlive: keepAlive,
|
|
}
|
|
}
|
|
|
|
// ForwardToListener forwards traffic between the container's remote
|
|
// port and a local port, which must already be listening for
|
|
// connections. (Accepting a listener rather than a port number avoids
|
|
// races against other processes opening ports, and against a client
|
|
// connecting to the socket prematurely.)
|
|
//
|
|
// ForwardToListener accepts and handles connections on the local port
|
|
// until it encounters the first error, which may include context
|
|
// cancellation. Its error result is always non-nil. The caller is
|
|
// responsible for closing the listening port.
|
|
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
|
|
id, err := fwd.shareRemotePort(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
errc := make(chan error, 1)
|
|
sendError := func(err error) {
|
|
// Use non-blocking send, to avoid goroutines getting
|
|
// stuck in case of concurrent or sequential errors.
|
|
select {
|
|
case errc <- err:
|
|
default:
|
|
}
|
|
}
|
|
go func() {
|
|
for {
|
|
conn, err := listen.AcceptTCP()
|
|
if err != nil {
|
|
sendError(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
if err := fwd.handleConnection(ctx, id, conn); err != nil {
|
|
sendError(err)
|
|
}
|
|
}()
|
|
}
|
|
}()
|
|
|
|
return awaitError(ctx, errc)
|
|
}
|
|
|
|
// Forward forwards traffic between the container's remote port and
|
|
// the specified read/write stream. On return, the stream is closed.
|
|
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
|
|
id, err := fwd.shareRemotePort(ctx)
|
|
if err != nil {
|
|
conn.Close()
|
|
return err
|
|
}
|
|
|
|
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
|
errc := make(chan error, 1)
|
|
go func() {
|
|
errc <- fwd.handleConnection(ctx, id, conn)
|
|
}()
|
|
return awaitError(ctx, errc)
|
|
}
|
|
|
|
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
|
|
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
|
|
}
|
|
|
|
return id, err
|
|
}
|
|
|
|
func awaitError(ctx context.Context, errc <-chan error) error {
|
|
select {
|
|
case err := <-errc:
|
|
return err
|
|
case <-ctx.Done():
|
|
return ctx.Err() // canceled
|
|
}
|
|
}
|
|
|
|
type trafficMonitorSession interface {
|
|
KeepAlive(string)
|
|
}
|
|
|
|
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
|
|
// it of the traffic type during Read operations.
|
|
type trafficMonitor struct {
|
|
reader io.Reader
|
|
|
|
session trafficMonitorSession
|
|
trafficType string
|
|
}
|
|
|
|
// newTrafficMonitor returns a new trafficMonitor for the specified
|
|
// session and traffic type. It wraps the provided io.Reader with its own
|
|
// Read method.
|
|
func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor {
|
|
return &trafficMonitor{reader, session, trafficType}
|
|
}
|
|
|
|
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
|
t.session.KeepAlive(t.trafficType)
|
|
return t.reader.Read(p)
|
|
}
|
|
|
|
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
|
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
|
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
|
defer span.Finish()
|
|
|
|
defer safeClose(conn, &err)
|
|
|
|
channel, err := fwd.session.OpenStreamingChannel(ctx, id)
|
|
if err != nil {
|
|
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
|
|
}
|
|
// Ideally we would call safeClose again, but (*ssh.channel).Close
|
|
// appears to have a bug that causes it return io.EOF spuriously
|
|
// if its peer closed first; see github.com/golang/go/issues/38115.
|
|
defer func() {
|
|
closeErr := channel.Close()
|
|
if err == nil && closeErr != io.EOF {
|
|
err = closeErr
|
|
}
|
|
}()
|
|
|
|
// bi-directional copy of data.
|
|
errs := make(chan error, 2)
|
|
copyConn := func(w ReadWriteHalfCloser, r io.Reader) {
|
|
_, err := io.Copy(w, r)
|
|
errs <- err
|
|
|
|
// Ignore errors here, we call the full Close() later and catch that error
|
|
_ = w.CloseWrite()
|
|
}
|
|
|
|
var (
|
|
channelReader io.Reader = channel
|
|
connReader io.Reader = conn
|
|
)
|
|
|
|
// If the forwader has been configured to keep the session alive
|
|
// it will monitor the I/O and notify the session of the traffic.
|
|
if fwd.keepAlive {
|
|
channelReader = newTrafficMonitor(channelReader, fwd.session, "output")
|
|
connReader = newTrafficMonitor(connReader, fwd.session, "input")
|
|
}
|
|
|
|
go copyConn(conn, channelReader)
|
|
go copyConn(channel, connReader)
|
|
|
|
// Wait until context is cancelled or both copies are done.
|
|
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
|
|
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
|
|
for i := 0; ; {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-errs:
|
|
i++
|
|
if i == 2 {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// safeClose reports the error (to *err) from closing the stream only
|
|
// if no other error was previously reported.
|
|
func safeClose(closer io.Closer, err *error) {
|
|
closeErr := closer.Close()
|
|
if *err == nil {
|
|
*err = closeErr
|
|
}
|
|
}
|