Merge pull request #6888 from cli/cmbrose/pf-half-close
Half close port forwarding connections to fix hangs
This commit is contained in:
commit
90ae71b2ba
9 changed files with 79 additions and 41 deletions
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
|
@ -78,3 +79,18 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
Logger: sessionLogger,
|
||||
})
|
||||
}
|
||||
|
||||
// ListenTCP starts a localhost tcp listener and returns the listener and bound port
|
||||
func ListenTCP(port int) (*net.TCPListener, int, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to build tcp address: %w", err)
|
||||
}
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to listen to local port over tcp: %w", err)
|
||||
}
|
||||
port = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
return listener, port, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,11 +70,11 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv
|
|||
|
||||
// Finds a free port to listen on and creates a new RPC invoker that connects to that port
|
||||
func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
|
||||
listener, err := listenTCP()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
|
||||
localAddress := listener.Addr().String()
|
||||
|
||||
invoker := &invoker{
|
||||
session: session,
|
||||
|
|
@ -235,6 +235,19 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS
|
|||
return port, response.User, nil
|
||||
}
|
||||
|
||||
func listenTCP() (*net.TCPListener, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build tcp address: %w", err)
|
||||
}
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// Periodically check whether there is a reason to keep the connection alive, and if so, notify the codespace to do so
|
||||
func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
|
|
@ -53,11 +52,10 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
|
|||
}()
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
|
||||
listen, localPort, err := ListenTCP(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
|
||||
invoker, err := rpc.CreateInvoker(ctx, session)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -58,7 +59,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
|
|||
a.StopProgressIndicator()
|
||||
|
||||
// Pass 0 to pick a random port
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
|
||||
listen, _, err := codespaces.ListenTCP(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package codespace
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
|
|
@ -49,12 +48,11 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
|
|||
defer safeClose(session, &err)
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
|
||||
listen, localPort, err := codespaces.ListenTCP(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
a.StartProgressIndicatorWithLabel("Fetching SSH Details")
|
||||
invoker, err := rpc.CreateInvoker(ctx, session)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -390,7 +389,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
|
|||
for _, pair := range portPairs {
|
||||
pair := pair
|
||||
group.Go(func() error {
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
|
||||
listen, _, err := codespaces.ListenTCP(pair.local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
|
|
@ -188,7 +186,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
|
||||
if opts.stdio {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
|
||||
stdio := newReadWriteCloser(os.Stdin, os.Stdout)
|
||||
stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout)
|
||||
err := fwd.Forward(ctx, stdio) // always non-nil
|
||||
return fmt.Errorf("tunnel closed: %w", err)
|
||||
}
|
||||
|
|
@ -199,12 +197,11 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
// Ensure local port is listening before client (Shell) connects.
|
||||
// Unless the user specifies a server port, localSSHServerPort is 0
|
||||
// and thus the client will pick a random port.
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localSSHServerPort))
|
||||
listen, localSSHServerPort, err := codespaces.ListenTCP(localSSHServerPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
connectDestination := opts.profile
|
||||
if connectDestination == "" {
|
||||
|
|
@ -745,21 +742,3 @@ func (fl *fileLogger) Name() string {
|
|||
func (fl *fileLogger) Close() error {
|
||||
return fl.f.Close()
|
||||
}
|
||||
|
||||
type combinedReadWriteCloser struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func newReadWriteCloser(reader io.ReadCloser, writer io.WriteCloser) io.ReadWriteCloser {
|
||||
return &combinedReadWriteCloser{reader, writer}
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteCloser) Close() error {
|
||||
werr := crwc.WriteCloser.Close()
|
||||
rerr := crwc.ReadCloser.Close()
|
||||
if werr != nil {
|
||||
return werr
|
||||
}
|
||||
return rerr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,33 @@ type portForwardingSession interface {
|
|||
KeepAlive(string)
|
||||
}
|
||||
|
||||
type ReadWriteHalfCloser interface {
|
||||
io.ReadWriteCloser
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type combinedReadWriteHalfCloser struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
|
||||
return &combinedReadWriteHalfCloser{reader, writer}
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteHalfCloser) Close() error {
|
||||
werr := crwc.WriteCloser.Close()
|
||||
rerr := crwc.ReadCloser.Close()
|
||||
if werr != nil {
|
||||
return werr
|
||||
}
|
||||
return rerr
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
|
||||
return crwc.WriteCloser.Close()
|
||||
}
|
||||
|
||||
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
|
||||
// container to a local destination such as a network port or Go reader/writer.
|
||||
type PortForwarder struct {
|
||||
|
|
@ -48,7 +75,7 @@ func NewPortForwarder(session portForwardingSession, name string, remotePort int
|
|||
// until it encounters the first error, which may include context
|
||||
// cancellation. Its error result is always non-nil. The caller is
|
||||
// responsible for closing the listening port.
|
||||
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
|
||||
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -65,7 +92,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
|
|||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listen.Accept()
|
||||
conn, err := listen.AcceptTCP()
|
||||
if err != nil {
|
||||
sendError(err)
|
||||
return
|
||||
|
|
@ -84,7 +111,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
|
|||
|
||||
// Forward forwards traffic between the container's remote port and
|
||||
// the specified read/write stream. On return, the stream is closed.
|
||||
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
|
|
@ -143,7 +170,7 @@ func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) {
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
||||
defer span.Finish()
|
||||
|
||||
|
|
@ -165,9 +192,12 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, co
|
|||
|
||||
// bi-directional copy of data.
|
||||
errs := make(chan error, 2)
|
||||
copyConn := func(w io.Writer, r io.Reader) {
|
||||
copyConn := func(w ReadWriteHalfCloser, r io.Reader) {
|
||||
_, err := io.Copy(w, r)
|
||||
errs <- err
|
||||
|
||||
// Ignore errors here, we call the full Close() later and catch that error
|
||||
_ = w.CloseWrite()
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -71,6 +71,10 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer listen.Close()
|
||||
tcpListener, ok := listen.(*net.TCPListener)
|
||||
if !ok {
|
||||
t.Fatal("net.Listen did not return a TCPListener")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
|
@ -82,7 +86,7 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
|
||||
done := make(chan error, 2)
|
||||
go func() {
|
||||
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen)
|
||||
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue