Have cilents call port connection function
This commit is contained in:
parent
a090b17e38
commit
0f41ccc472
2 changed files with 11 additions and 20 deletions
|
|
@ -19,8 +19,9 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 5 * time.Second
|
||||
requestTimeout = 30 * time.Second
|
||||
serverConnectionTimeout = 5 * time.Second
|
||||
requestTimeout = 30 * time.Second
|
||||
portConnectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -57,14 +58,17 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
|
|||
}()
|
||||
|
||||
// Ping the port to ensure that it is fully forwarded before continuing
|
||||
liveshare.WaitForPortConnection(ctx, localAddress)
|
||||
err = liveshare.WaitForPortConnection(ctx, localAddress, portConnectionTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to local port: %w", err)
|
||||
}
|
||||
|
||||
// Attempt to connect to the port
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
}
|
||||
ctx, _ = context.WithTimeout(ctx, connectionTimeout)
|
||||
ctx, _ = context.WithTimeout(ctx, serverConnectionTimeout)
|
||||
conn, err := grpc.DialContext(ctx, localAddress, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -12,10 +12,6 @@ import (
|
|||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type portForwardingSession interface {
|
||||
StartSharing(context.Context, string, int) (ChannelID, error)
|
||||
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
|
||||
|
|
@ -60,12 +56,6 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
|
|||
return err
|
||||
}
|
||||
|
||||
// Ping the port to ensure that it is fully forwarded before continuing
|
||||
err = WaitForPortConnection(ctx, listen.Addr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
sendError := func(err error) {
|
||||
// Use non-blocking send, to avoid goroutines getting
|
||||
|
|
@ -112,20 +102,17 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
|
|||
}
|
||||
|
||||
// Connects to and pings a given address to ensure that the server is shared and the port is forwarded.
|
||||
func WaitForPortConnection(ctx context.Context, address string) error {
|
||||
waitCtx, cancel := context.WithTimeout(ctx, connectionTimeout)
|
||||
func WaitForPortConnection(ctx context.Context, address string, timeout time.Duration) error {
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
g, waitCtx := errgroup.WithContext(waitCtx)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return fmt.Errorf("timed out waiting for connection")
|
||||
case <-ticker.C:
|
||||
default:
|
||||
// Verify that the port can be connected to
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue