Merge pull request #13 from github/listen-race

move Listen call into clients to avoid race
This commit is contained in:
Alan Donovan 2021-09-03 10:05:02 -04:00 committed by GitHub
commit 10a129b764
2 changed files with 26 additions and 13 deletions

View file

@ -26,22 +26,29 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar
}
}
// ForwardToLocalPort forwards traffic between the container's remote
// port and a local TCP port. It accepts and handles connections on
// the local port until it encounters the first error, which may
// include context cancellation. Its error result is always non-nil.
func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) {
// ListenTCP calls listen on the chosen local TCP port. Zero picks an
// arbitrary port. It is provided for the convenience of callers of
// ForwardToListener.
func ListenTCP(port int) (net.Listener, error) {
return net.Listen("tcp", fmt.Sprintf(":%d", port))
}
// ForwardToListener forwards traffic between the container's remote
// port and a local port, which must already be listening for
// connections. (Accepting a listener rather than a port number avoids
// races against other processes opening ports, and against a client
// connecting to the socket prematurely.)
//
// ForwardToListener accepts and handles connections on the local port
// until it encounters the first error, which may include context
// cancellation. Its error result is always non-nil. The caller is
// responsible for closing the listening port.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
}
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
if err != nil {
return fmt.Errorf("error listening on TCP port: %v", err)
}
defer safeClose(listen, &err)
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting

View file

@ -46,13 +46,19 @@ func TestPortForwarderStart(t *testing.T) {
}
defer testServer.Close()
listen, err := ListenTCP(8000) // local port
if err != nil {
t.Fatal(err)
}
defer listen.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error)
go func() {
const name, local, remote = "ssh", 8000, 8000
done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local)
const name, remote = "ssh", 8000
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
}()
go func() {