Merge pull request #8 from github/portfwd-errors
handle errors in port forwarding
This commit is contained in:
commit
fc4c678d03
2 changed files with 77 additions and 41 deletions
|
|
@ -5,77 +5,116 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// A PortForwader can forward ports from a remote liveshare host to localhost
|
||||
// A PortForwarder forwards TCP traffic between a port on a remote
|
||||
// LiveShare host and a local port.
|
||||
type PortForwarder struct {
|
||||
client *Client
|
||||
server *Server
|
||||
port int
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
// NewPortForwarder creates a new PortForwader with a given client, server and port
|
||||
// NewPortForwarder creates a new PortForwarder that connects a given client, server and port.
|
||||
func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
client: client,
|
||||
server: server,
|
||||
port: port,
|
||||
errCh: make(chan error),
|
||||
}
|
||||
}
|
||||
|
||||
// Start is a method to start forwarding the server to a localhost port
|
||||
func (l *PortForwarder) Start(ctx context.Context) error {
|
||||
ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port))
|
||||
// Forward enables port forwarding. It accepts and handles TCP
|
||||
// connections until it encounters the first error, which may include
|
||||
// context cancellation. Its result is non-nil.
|
||||
func (l *PortForwarder) Forward(ctx context.Context) (err error) {
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listening on tcp port: %v", err)
|
||||
return fmt.Errorf("error listening on TCP port: %v", err)
|
||||
}
|
||||
defer safeClose(listen, &err)
|
||||
|
||||
errc := make(chan error, 1)
|
||||
sendError := func(err error) {
|
||||
// Use non-blocking send, to avoid goroutines getting
|
||||
// stuck in case of concurrent or sequential errors.
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
conn, err := listen.Accept()
|
||||
if err != nil {
|
||||
l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err)
|
||||
sendError(err)
|
||||
return
|
||||
}
|
||||
|
||||
go l.handleConnection(ctx, conn)
|
||||
go func() {
|
||||
if err := l.handleConnection(ctx, conn); err != nil {
|
||||
sendError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// ForwardWithConn handles port forwarding for a single connection.
|
||||
func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
if err := l.handleConnection(ctx, conn); err != nil {
|
||||
errc <- err
|
||||
}
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case err := <-l.errCh:
|
||||
case err := <-errc:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ln.Close()
|
||||
return ctx.Err() // canceled
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
|
||||
defer safeClose(conn, &err)
|
||||
|
||||
channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening streaming channel for new connection: %v", err)
|
||||
}
|
||||
defer safeClose(channel, &err)
|
||||
|
||||
errs := make(chan error, 2)
|
||||
copyConn := func(w io.Writer, r io.Reader) {
|
||||
_, err := io.Copy(w, r)
|
||||
errs <- err
|
||||
}
|
||||
go copyConn(conn, channel)
|
||||
go copyConn(channel, conn)
|
||||
|
||||
// await result
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-errs; err != nil && err != io.EOF {
|
||||
return fmt.Errorf("tunnel connection: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
go l.handleConnection(ctx, conn)
|
||||
return <-l.errCh
|
||||
}
|
||||
|
||||
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) {
|
||||
channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition)
|
||||
if err != nil {
|
||||
l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err)
|
||||
return
|
||||
// safeClose reports the error (to *err) from closing the stream only
|
||||
// if no other error was previously reported.
|
||||
func safeClose(closer io.Closer, err *error) {
|
||||
closeErr := closer.Close()
|
||||
if *err == nil {
|
||||
*err = closeErr
|
||||
}
|
||||
|
||||
copyConn := func(writer io.Writer, reader io.Reader) {
|
||||
if _, err := io.Copy(writer, reader); err != nil {
|
||||
channel.Close()
|
||||
conn.Close()
|
||||
if err != io.EOF {
|
||||
l.errCh <- fmt.Errorf("tunnel connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go copyConn(conn, channel)
|
||||
go copyConn(channel, conn)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,10 +63,7 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
if err := server.StartSharing(ctx, "http", 8000); err != nil {
|
||||
done <- fmt.Errorf("start sharing: %v", err)
|
||||
}
|
||||
if err := pf.Start(ctx); err != nil {
|
||||
done <- err
|
||||
}
|
||||
done <- nil
|
||||
done <- pf.Forward(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue