diff --git a/port_forwarder.go b/port_forwarder.go index f4895bb60..fe0d7d80e 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,22 +26,28 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } +// ListenTCP calls listen on the chosen local TCP port. Zero picks an arbitrary port. +// It is provided for the convenience of callers of ForwardToLocalPort. +func Listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + // ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port. It accepts and handles connections on -// the local port until it encounters the first error, which may -// include context cancellation. Its error result is always non-nil. -func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { +// port and a local TCP port, which must already be listening for +// connections. (Accepting a listener rather than a port number avoids +// races against other processes opening ports, and against a client +// connecting to the socket prematurely.) +// +// ForwardToLocalPort accepts and handles connections on the local +// port until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err } - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) - if err != nil { - return fmt.Errorf("error listening on TCP port: %v", err) - } - defer safeClose(listen, &err) - errc := make(chan error, 1) sendError := func(err error) { // Use non-blocking send, to avoid goroutines getting diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 6ccb3d05e..68b658b6b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,13 +46,19 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() + listen, err := Listen(8000) // local port + if err != nil { + t.Fatal(err) + } + defer listen.Close() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error) go func() { - const name, local, remote = "ssh", 8000, 8000 - done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) + const name, remote = "ssh", 8000 + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, listen) }() go func() {