Merge branch 'main' into wip

This commit is contained in:
Alan Donovan 2021-09-01 17:55:31 -04:00
commit ca6d074333
4 changed files with 74 additions and 37 deletions

View file

@ -69,7 +69,7 @@ func (c *Client) Join(ctx context.Context) (err error) {
_, err = c.joinWorkspace(ctx)
if err != nil {
return fmt.Errorf("error joining liveshare workspace: %v", err)
return fmt.Errorf("error joining Live Share workspace: %v", err)
}
return nil

View file

@ -75,7 +75,7 @@ func TestClientJoin(t *testing.T) {
livesharetest.WithRelaySAS(connection.RelaySAS),
)
if err != nil {
t.Errorf("error creating liveshare server: %v", err)
t.Errorf("error creating Live Share server: %v", err)
}
defer server.Close()
connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")

View file

@ -5,76 +5,116 @@ import (
"fmt"
"io"
"net"
"strconv"
)
// A PortForwader can forward ports from a remote liveshare host to localhost
// A PortForwarder forwards TCP traffic between a port on a remote
// LiveShare host and a local port.
type PortForwarder struct {
client *Client
server *Server
port int
errCh chan error
}
// NewPortForwarder creates a new PortForwader with a given client, server and port
// NewPortForwarder creates a new PortForwarder that connects a given client, server and port.
func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder {
return &PortForwarder{
client: client,
server: server,
port: port,
errCh: make(chan error),
}
}
// Start is a method to start forwarding the server to a localhost port
func (l *PortForwarder) Start(ctx context.Context) error {
ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port))
// Forward enables port forwarding. It accepts and handles TCP
// connections until it encounters the first error, which may include
// context cancellation. Its result is non-nil.
func (l *PortForwarder) Forward(ctx context.Context) (err error) {
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port))
if err != nil {
return fmt.Errorf("error listening on tcp port: %v", err)
return fmt.Errorf("error listening on TCP port: %v", err)
}
defer safeClose(listen, &err)
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}
go func() {
for {
conn, err := ln.Accept()
conn, err := listen.Accept()
if err != nil {
l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err)
sendError(err)
return
}
go l.handleConnection(ctx, conn)
go func() {
if err := l.handleConnection(ctx, conn); err != nil {
sendError(err)
}
}()
}
}()
return awaitError(ctx, errc)
}
// ForwardWithConn handles port forwarding for a single connection.
func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
// Create buffered channel so that send doesn't get stuck after context cancellation.
errc := make(chan error, 1)
go func() {
if err := l.handleConnection(ctx, conn); err != nil {
errc <- err
}
}()
return awaitError(ctx, errc)
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case err := <-l.errCh:
case err := <-errc:
return err
case <-ctx.Done():
// TODO ctx.Error?
return ln.Close()
return ctx.Err() // canceled
}
}
func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
go l.handleConnection(ctx, conn)
return <-l.errCh
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
defer safeClose(conn, &err)
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) {
channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition)
if err != nil {
l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err)
return
return fmt.Errorf("error opening streaming channel for new connection: %v", err)
}
defer safeClose(channel, &err)
copyConn := func(writer io.Writer, reader io.Reader) {
if _, err := io.Copy(writer, reader); err != nil {
channel.Close()
conn.Close()
if err != io.EOF {
l.errCh <- fmt.Errorf("tunnel connection: %v", err)
}
}
errs := make(chan error, 2)
copyConn := func(w io.Writer, r io.Reader) {
_, err := io.Copy(w, r)
errs <- err
}
go copyConn(conn, channel)
go copyConn(channel, conn)
// await result
for i := 0; i < 2; i++ {
if err := <-errs; err != nil && err != io.EOF {
return fmt.Errorf("tunnel connection: %v", err)
}
}
return nil
}
// safeClose reports the error (to *err) from closing the stream only
// if no other error was previously reported.
func safeClose(closer io.Closer, err *error) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}

View file

@ -64,10 +64,7 @@ func TestPortForwarderStart(t *testing.T) {
if err := server.StartSharing(ctx, "http", 8000); err != nil {
done <- fmt.Errorf("start sharing: %v", err)
}
if err := pf.Start(ctx); err != nil {
done <- err
}
done <- nil
done <- pf.Forward(ctx)
}()
go func() {