117 lines
2.9 KiB
Go
117 lines
2.9 KiB
Go
package liveshare
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
)
|
|
|
|
// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session.
type PortForwarder struct {
	// session supplies the streaming channels (and their stream name and
	// condition) that each forwarded connection is relayed through.
	session *Session
	// port is the local TCP port number that Forward listens on.
	port int
}
|
|
|
|
// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port.
|
|
func NewPortForwarder(session *Session, port int) *PortForwarder {
|
|
return &PortForwarder{
|
|
session: session,
|
|
port: port,
|
|
}
|
|
}
|
|
|
|
// Forward enables port forwarding. It accepts and handles TCP
|
|
// connections until it encounters the first error, which may include
|
|
// context cancellation. Its result is non-nil.
|
|
func (l *PortForwarder) Forward(ctx context.Context) (err error) {
|
|
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port))
|
|
if err != nil {
|
|
return fmt.Errorf("error listening on TCP port: %v", err)
|
|
}
|
|
defer safeClose(listen, &err)
|
|
|
|
errc := make(chan error, 1)
|
|
sendError := func(err error) {
|
|
// Use non-blocking send, to avoid goroutines getting
|
|
// stuck in case of concurrent or sequential errors.
|
|
select {
|
|
case errc <- err:
|
|
default:
|
|
}
|
|
}
|
|
go func() {
|
|
for {
|
|
conn, err := listen.Accept()
|
|
if err != nil {
|
|
sendError(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
if err := l.handleConnection(ctx, conn); err != nil {
|
|
sendError(err)
|
|
}
|
|
}()
|
|
}
|
|
}()
|
|
|
|
return awaitError(ctx, errc)
|
|
}
|
|
|
|
// ForwardWithConn handles port forwarding for a single connection.
|
|
func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
|
|
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
|
errc := make(chan error, 1)
|
|
go func() {
|
|
if err := l.handleConnection(ctx, conn); err != nil {
|
|
errc <- err
|
|
}
|
|
}()
|
|
return awaitError(ctx, errc)
|
|
}
|
|
|
|
// awaitError blocks until a value can be received from errc or ctx is done.
// It returns the received value in the first case and ctx.Err() in the
// second. If both are ready at once, either outcome may be chosen.
func awaitError(ctx context.Context, errc <-chan error) error {
	var result error
	select {
	case received := <-errc:
		result = received
	case <-ctx.Done():
		result = ctx.Err() // canceled
	}
	return result
}
|
|
|
|
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
|
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
|
|
defer safeClose(conn, &err)
|
|
|
|
channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition)
|
|
if err != nil {
|
|
return fmt.Errorf("error opening streaming channel for new connection: %v", err)
|
|
}
|
|
defer safeClose(channel, &err)
|
|
|
|
errs := make(chan error, 2)
|
|
copyConn := func(w io.Writer, r io.Reader) {
|
|
_, err := io.Copy(w, r)
|
|
errs <- err
|
|
}
|
|
go copyConn(conn, channel)
|
|
go copyConn(channel, conn)
|
|
|
|
// await result
|
|
for i := 0; i < 2; i++ {
|
|
if err := <-errs; err != nil && err != io.EOF {
|
|
return fmt.Errorf("tunnel connection: %v", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// safeClose always closes closer. It stores the Close result in *err only
// when *err is still nil, so an error reported earlier is never overwritten.
// It is intended to be deferred with a pointer to a named error result.
func safeClose(closer io.Closer, err *error) {
	if closeErr := closer.Close(); *err == nil {
		*err = closeErr
	}
}
|