diff --git a/port_forwarder.go b/port_forwarder.go index 1351025cb..f47d11565 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -123,13 +123,28 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co }() // Bi-directional copy of data. - // If any individual connection has an error, we can safely ignore them - // and defer to connection clients to handle data loss as necessary. - go io.Copy(conn, channel) - go io.Copy(channel, conn) + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) - <-ctx.Done() - return ctx.Err() + // wait until context is cancelled or we've received two io.EOF +Loop: + for i := 0; i < 2; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errs: + if err != nil && err != io.EOF { + break Loop // non-EOF errors stop connection handling + } + } + } + + return nil } // safeClose reports the error (to *err) from closing the stream only