Have cilents call port connection function

This commit is contained in:
David Gardiner 2022-10-10 11:56:44 -07:00
parent a090b17e38
commit 0f41ccc472
2 changed files with 11 additions and 20 deletions

View file

@ -19,8 +19,9 @@ import (
)
const (
connectionTimeout = 5 * time.Second
requestTimeout = 30 * time.Second
serverConnectionTimeout = 5 * time.Second
requestTimeout = 30 * time.Second
portConnectionTimeout = 30 * time.Second
)
const (
@ -57,14 +58,17 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
}()
// Ping the port to ensure that it is fully forwarded before continuing
liveshare.WaitForPortConnection(ctx, localAddress)
err = liveshare.WaitForPortConnection(ctx, localAddress, portConnectionTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to local port: %w", err)
}
// Attempt to connect to the port
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
}
ctx, _ = context.WithTimeout(ctx, connectionTimeout)
ctx, _ = context.WithTimeout(ctx, serverConnectionTimeout)
conn, err := grpc.DialContext(ctx, localAddress, opts...)
if err != nil {
return nil, err

View file

@ -12,10 +12,6 @@ import (
"golang.org/x/sync/errgroup"
)
const (
connectionTimeout = 30 * time.Second
)
type portForwardingSession interface {
StartSharing(context.Context, string, int) (ChannelID, error)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
@ -60,12 +56,6 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
return err
}
// Ping the port to ensure that it is fully forwarded before continuing
err = WaitForPortConnection(ctx, listen.Addr().String())
if err != nil {
return err
}
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
@ -112,20 +102,17 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
}
// Connects to and pings a given address to ensure that the server is shared and the port is forwarded.
func WaitForPortConnection(ctx context.Context, address string) error {
waitCtx, cancel := context.WithTimeout(ctx, connectionTimeout)
func WaitForPortConnection(ctx context.Context, address string, timeout time.Duration) error {
waitCtx, cancel := context.WithTimeout(ctx, timeout)
g, waitCtx := errgroup.WithContext(waitCtx)
defer cancel()
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
g.Go(func() error {
for {
select {
case <-waitCtx.Done():
return fmt.Errorf("timed out waiting for connection")
case <-ticker.C:
default:
// Verify that the port can be connected to
conn, err := net.Dial("tcp", address)
if err != nil {