Merge pull request #6888 from cli/cmbrose/pf-half-close

Half close port forwarding connections to fix hangs
This commit is contained in:
Caleb Brose 2023-01-23 14:26:19 -06:00 committed by GitHub
commit 90ae71b2ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 79 additions and 41 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/cenkalti/backoff/v4"
@ -78,3 +79,18 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
Logger: sessionLogger,
})
}
// ListenTCP starts a localhost tcp listener and returns the listener and bound port
func ListenTCP(port int) (*net.TCPListener, int, error) {
addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return nil, 0, fmt.Errorf("failed to build tcp address: %w", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, 0, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}
port = listener.Addr().(*net.TCPAddr).Port
return listener, port, nil
}

View file

@ -70,11 +70,11 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv
// Finds a free port to listen on and creates a new RPC invoker that connects to that port
func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
listener, err := listenTCP()
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
return nil, err
}
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
localAddress := listener.Addr().String()
invoker := &invoker{
session: session,
@ -235,6 +235,19 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS
return port, response.User, nil
}
func listenTCP() (*net.TCPListener, error) {
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("failed to build tcp address: %w", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}
return listener, nil
}
// Periodically check whether there is a reason to keep the connection alive, and if so, notify the codespace to do so
func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)

View file

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"log"
"net"
"time"
"github.com/cli/cli/v2/internal/codespaces/api"
@ -53,11 +52,10 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
}()
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
listen, localPort, err := ListenTCP(0)
if err != nil {
return err
}
localPort := listen.Addr().(*net.TCPAddr).Port
progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := rpc.CreateInvoker(ctx, session)

View file

@ -6,6 +6,7 @@ import (
"net"
"strings"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
@ -58,7 +59,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
a.StopProgressIndicator()
// Pass 0 to pick a random port
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
listen, _, err := codespaces.ListenTCP(0)
if err != nil {
return err
}

View file

@ -3,7 +3,6 @@ package codespace
import (
"context"
"fmt"
"net"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
@ -49,12 +48,11 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
defer safeClose(session, &err)
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
listen, localPort, err := codespaces.ListenTCP(0)
if err != nil {
return err
}
defer listen.Close()
localPort := listen.Addr().(*net.TCPAddr).Port
a.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := rpc.CreateInvoker(ctx, session)

View file

@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
@ -390,7 +389,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
for _, pair := range portPairs {
pair := pair
group.Go(func() error {
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
listen, _, err := codespaces.ListenTCP(pair.local)
if err != nil {
return err
}

View file

@ -6,9 +6,7 @@ import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"path"
@ -188,7 +186,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
if opts.stdio {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
stdio := newReadWriteCloser(os.Stdin, os.Stdout)
stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout)
err := fwd.Forward(ctx, stdio) // always non-nil
return fmt.Errorf("tunnel closed: %w", err)
}
@ -199,12 +197,11 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
// Ensure local port is listening before client (Shell) connects.
// Unless the user specifies a server port, localSSHServerPort is 0
// and thus the client will pick a random port.
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localSSHServerPort))
listen, localSSHServerPort, err := codespaces.ListenTCP(localSSHServerPort)
if err != nil {
return err
}
defer listen.Close()
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
connectDestination := opts.profile
if connectDestination == "" {
@ -745,21 +742,3 @@ func (fl *fileLogger) Name() string {
func (fl *fileLogger) Close() error {
return fl.f.Close()
}
type combinedReadWriteCloser struct {
io.ReadCloser
io.WriteCloser
}
func newReadWriteCloser(reader io.ReadCloser, writer io.WriteCloser) io.ReadWriteCloser {
return &combinedReadWriteCloser{reader, writer}
}
func (crwc *combinedReadWriteCloser) Close() error {
werr := crwc.WriteCloser.Close()
rerr := crwc.ReadCloser.Close()
if werr != nil {
return werr
}
return rerr
}

View file

@ -16,6 +16,33 @@ type portForwardingSession interface {
KeepAlive(string)
}
type ReadWriteHalfCloser interface {
io.ReadWriteCloser
CloseWrite() error
}
type combinedReadWriteHalfCloser struct {
io.ReadCloser
io.WriteCloser
}
func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
return &combinedReadWriteHalfCloser{reader, writer}
}
func (crwc *combinedReadWriteHalfCloser) Close() error {
werr := crwc.WriteCloser.Close()
rerr := crwc.ReadCloser.Close()
if werr != nil {
return werr
}
return rerr
}
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
return crwc.WriteCloser.Close()
}
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
@ -48,7 +75,7 @@ func NewPortForwarder(session portForwardingSession, name string, remotePort int
// until it encounters the first error, which may include context
// cancellation. Its error result is always non-nil. The caller is
// responsible for closing the listening port.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
@ -65,7 +92,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
}
go func() {
for {
conn, err := listen.Accept()
conn, err := listen.AcceptTCP()
if err != nil {
sendError(err)
return
@ -84,7 +111,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
// Forward forwards traffic between the container's remote port and
// the specified read/write stream. On return, the stream is closed.
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
conn.Close()
@ -143,7 +170,7 @@ func (t *trafficMonitor) Read(p []byte) (n int, err error) {
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) {
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
defer span.Finish()
@ -165,9 +192,12 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, co
// bi-directional copy of data.
errs := make(chan error, 2)
copyConn := func(w io.Writer, r io.Reader) {
copyConn := func(w ReadWriteHalfCloser, r io.Reader) {
_, err := io.Copy(w, r)
errs <- err
// Ignore errors here, we call the full Close() later and catch that error
_ = w.CloseWrite()
}
var (

View file

@ -71,6 +71,10 @@ func TestPortForwarderStart(t *testing.T) {
t.Fatal(err)
}
defer listen.Close()
tcpListener, ok := listen.(*net.TCPListener)
if !ok {
t.Fatal("net.Listen did not return a TCPListener")
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -82,7 +86,7 @@ func TestPortForwarderStart(t *testing.T) {
done := make(chan error, 2)
go func() {
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen)
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener)
}()
go func() {