diff --git a/port_forwarder.go b/port_forwarder.go index 7d3363ba2..400d6ac97 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -84,9 +84,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := fwd.handleConnection(ctx, id, conn); err != nil { - errc <- err - } + errc <- fwd.handleConnection(ctx, id, conn) }() return awaitError(ctx, errc) } @@ -129,6 +127,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co } }() + // bi-directional copy of data. errs := make(chan error, 2) copyConn := func(w io.Writer, r io.Reader) { _, err := io.Copy(w, r) @@ -137,13 +136,20 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co go copyConn(conn, channel) go copyConn(channel, conn) - // await result - for i := 0; i < 2; i++ { - if err := <-errs; err != nil && err != io.EOF { - return fmt.Errorf("tunnel connection: %v", err) + // Wait until context is cancelled or both copies are done. + // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. + // TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file? + for i := 0; ; { + select { + case <-ctx.Done(): + return ctx.Err() + case <-errs: + i++ + if i == 2 { + return nil + } } } - return nil } // safeClose reports the error (to *err) from closing the stream only