diff --git a/internal/codespaces/grpc/client.go b/internal/codespaces/grpc/client.go index 7ceeac7d1..cf8137006 100644 --- a/internal/codespaces/grpc/client.go +++ b/internal/codespaces/grpc/client.go @@ -49,14 +49,15 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie } // Tunnel the remote gRPC server port to the local port - localPort := listener.Addr().(*net.TCPAddr).Port + localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) internalTunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) internalTunnelClosed <- fwd.ForwardToListener(ctx, listener) }() - time.Sleep(time.Millisecond) + // Ping the port to ensure that it is fully forwarded before continuing + liveshare.WaitForPortConnection(ctx, localAddress) // Attempt to connect to the port opts := []grpc.DialOption{ @@ -64,7 +65,7 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie grpc.WithBlock(), } ctx, _ = context.WithTimeout(ctx, connectionTimeout) - conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", localPort), opts...) + conn, err := grpc.DialContext(ctx, localAddress, opts...) if err != nil { return nil, err } diff --git a/internal/codespaces/grpc/test/session.go b/internal/codespaces/grpc/test/session.go index e29027dc6..ec4d69649 100644 --- a/internal/codespaces/grpc/test/session.go +++ b/internal/codespaces/grpc/test/session.go @@ -21,6 +21,7 @@ func (s *Session) StartSharing(ctx context.Context, sessionName string, port int return liveshare.ChannelID{}, nil } +// Creates mock SSH channel connected to the mock gRPC server func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) { dialer := net.Dialer{} conn, err := dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index f042eeaea..33dc6d1fa 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -5,9 +5,15 @@ import ( "fmt" "io" "net" + "time" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" +) + +const ( + connectionTimeout = 30 * time.Second ) type portForwardingSession interface { @@ -54,6 +60,12 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List return err } + // Ping the port to ensure that it is fully forwarded before continuing + err = WaitForPortConnection(ctx, listen.Addr().String()) + if err != nil { + return err + } + errc := make(chan error, 1) sendError := func(err error) { // Use non-blocking send, to avoid goroutines getting @@ -99,6 +111,43 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) return awaitError(ctx, errc) } +// Connects to and pings a given address to ensure that the server is shared and the port is forwarded. +func WaitForPortConnection(ctx context.Context, address string) error { + waitCtx, cancel := context.WithTimeout(ctx, connectionTimeout) + g, waitCtx := errgroup.WithContext(waitCtx) + defer cancel() + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + g.Go(func() error { + for { + select { + case <-waitCtx.Done(): + return fmt.Errorf("timed out waiting for connection") + case <-ticker.C: + // Verify that the port can be connected to + conn, err := net.Dial("tcp", address) + if err != nil { + continue + } + + defer conn.Close() + + // Send a ping and make sure it succeed + _, err = conn.Write([]byte("ping")) + if err != nil { + continue + } + + return nil + } + } + }) + + return g.Wait() +} + func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) { id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort) if err != nil {