diff --git a/internal/codespaces/grpc/client.go b/internal/codespaces/grpc/client.go index cf8137006..ea8399683 100644 --- a/internal/codespaces/grpc/client.go +++ b/internal/codespaces/grpc/client.go @@ -19,8 +19,9 @@ import ( ) const ( - connectionTimeout = 5 * time.Second - requestTimeout = 30 * time.Second + serverConnectionTimeout = 5 * time.Second + requestTimeout = 30 * time.Second + portConnectionTimeout = 30 * time.Second ) const ( @@ -57,14 +58,17 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie }() // Ping the port to ensure that it is fully forwarded before continuing - liveshare.WaitForPortConnection(ctx, localAddress) + err = liveshare.WaitForPortConnection(ctx, localAddress, portConnectionTimeout) + if err != nil { + return nil, fmt.Errorf("failed to connect to local port: %w", err) + } // Attempt to connect to the port opts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), } - ctx, _ = context.WithTimeout(ctx, connectionTimeout) + ctx, _ = context.WithTimeout(ctx, serverConnectionTimeout) conn, err := grpc.DialContext(ctx, localAddress, opts...) if err != nil { return nil, err diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index 33dc6d1fa..ec23fa3d3 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -12,10 +12,6 @@ import ( "golang.org/x/sync/errgroup" ) -const ( - connectionTimeout = 30 * time.Second -) - type portForwardingSession interface { StartSharing(context.Context, string, int) (ChannelID, error) OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) @@ -60,12 +56,6 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List return err } - // Ping the port to ensure that it is fully forwarded before continuing - err = WaitForPortConnection(ctx, listen.Addr().String()) - if err != nil { - return err - } - errc := make(chan error, 1) sendError := func(err error) { // Use non-blocking send, to avoid goroutines getting @@ -112,20 +102,17 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) } // Connects to and pings a given address to ensure that the server is shared and the port is forwarded. -func WaitForPortConnection(ctx context.Context, address string) error { - waitCtx, cancel := context.WithTimeout(ctx, connectionTimeout) +func WaitForPortConnection(ctx context.Context, address string, timeout time.Duration) error { + waitCtx, cancel := context.WithTimeout(ctx, timeout) g, waitCtx := errgroup.WithContext(waitCtx) defer cancel() - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - g.Go(func() error { for { select { case <-waitCtx.Done(): return fmt.Errorf("timed out waiting for connection") - case <-ticker.C: + default: // Verify that the port can be connected to conn, err := net.Dial("tcp", address) if err != nil {