Ensure port is forwarded and server is shared

This commit is contained in:
David Gardiner 2022-10-07 15:37:04 -07:00
parent 341fc6c3f7
commit a090b17e38
3 changed files with 54 additions and 3 deletions

View file

@ -49,14 +49,15 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
}
// Tunnel the remote gRPC server port to the local port
localPort := listener.Addr().(*net.TCPAddr).Port
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
internalTunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
internalTunnelClosed <- fwd.ForwardToListener(ctx, listener)
}()
time.Sleep(time.Millisecond)
// Ping the port to ensure that it is fully forwarded before continuing
liveshare.WaitForPortConnection(ctx, localAddress)
// Attempt to connect to the port
opts := []grpc.DialOption{
@ -64,7 +65,7 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
grpc.WithBlock(),
}
ctx, _ = context.WithTimeout(ctx, connectionTimeout)
conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", localPort), opts...)
conn, err := grpc.DialContext(ctx, localAddress, opts...)
if err != nil {
return nil, err
}

View file

@ -21,6 +21,7 @@ func (s *Session) StartSharing(ctx context.Context, sessionName string, port int
return liveshare.ChannelID{}, nil
}
// Creates mock SSH channel connected to the mock gRPC server
func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) {
dialer := net.Dialer{}
conn, err := dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))

View file

@ -5,9 +5,15 @@ import (
"fmt"
"io"
"net"
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
const (
connectionTimeout = 30 * time.Second
)
type portForwardingSession interface {
@ -54,6 +60,12 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
return err
}
// Ping the port to ensure that it is fully forwarded before continuing
err = WaitForPortConnection(ctx, listen.Addr().String())
if err != nil {
return err
}
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
@ -99,6 +111,43 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
return awaitError(ctx, errc)
}
// Connects to and pings a given address to ensure that the server is shared and the port is forwarded.
func WaitForPortConnection(ctx context.Context, address string) error {
waitCtx, cancel := context.WithTimeout(ctx, connectionTimeout)
g, waitCtx := errgroup.WithContext(waitCtx)
defer cancel()
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
g.Go(func() error {
for {
select {
case <-waitCtx.Done():
return fmt.Errorf("timed out waiting for connection")
case <-ticker.C:
// Verify that the port can be connected to
conn, err := net.Dial("tcp", address)
if err != nil {
continue
}
defer conn.Close()
// Send a ping and make sure it succeed
_, err = conn.Write([]byte("ping"))
if err != nil {
continue
}
return nil
}
}
})
return g.Wait()
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {