Resolve race condition in codespaces connection
This commit is contained in:
parent
ede1705bf2
commit
581b6652e2
10 changed files with 67 additions and 20 deletions
|
|
@ -16,10 +16,15 @@ const (
|
|||
clientName = "gh"
|
||||
)
|
||||
|
||||
type TunnelClient struct {
|
||||
*tunnels.Client
|
||||
connected bool
|
||||
}
|
||||
|
||||
type CodespaceConnection struct {
|
||||
tunnelProperties api.TunnelProperties
|
||||
TunnelManager *tunnels.Manager
|
||||
TunnelClient *tunnels.Client
|
||||
TunnelClient *TunnelClient
|
||||
Options *tunnels.TunnelRequestOptions
|
||||
Tunnel *tunnels.Tunnel
|
||||
AllowedPortPrivacySettings []string
|
||||
|
|
@ -74,6 +79,38 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Connect connects the client to the tunnel.
|
||||
func (c *CodespaceConnection) Connect(ctx context.Context) error {
|
||||
// If already connected, return
|
||||
if c.TunnelClient.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect to the tunnel
|
||||
if err := c.TunnelClient.Client.Connect(ctx, ""); err != nil {
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Set the connected flag so we know we're connected
|
||||
c.TunnelClient.connected = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the underlying tunnel client SSH connection.
|
||||
func (c *CodespaceConnection) Close() error {
|
||||
// Don't close if we're not connected
|
||||
if c.TunnelClient != nil && c.TunnelClient.connected {
|
||||
if err := c.TunnelClient.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close tunnel client connection: %w", err)
|
||||
}
|
||||
|
||||
c.TunnelClient.connected = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTunnelManager creates a tunnel manager for the given codespace.
|
||||
// The tunnel manager is used to get the tunnel hosted in the codespace that we
|
||||
// want to connect to and perform operations on ports (add, remove, list, etc.).
|
||||
|
|
@ -96,7 +133,7 @@ func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Cl
|
|||
// getTunnelClient creates a tunnel client for the given tunnel.
|
||||
// The tunnel client is used to connect to the the tunnel and allows
|
||||
// for ports to be forwarded locally.
|
||||
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
|
||||
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *TunnelClient, err error) {
|
||||
// Get the tunnel that we want to connect to
|
||||
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
|
||||
if err != nil {
|
||||
|
|
@ -107,10 +144,15 @@ func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel
|
|||
codespaceTunnel.AccessTokens = tunnel.AccessTokens
|
||||
|
||||
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
|
||||
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
|
||||
client, err := tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tunnel client: %w", err)
|
||||
}
|
||||
|
||||
tunnelClient = &TunnelClient{
|
||||
Client: client,
|
||||
connected: false,
|
||||
}
|
||||
|
||||
return tunnelClient, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,6 +41,12 @@ func TestNewCodespaceConnection(t *testing.T) {
|
|||
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Verify closing before connected doesn't throw
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Close returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Check that the connection was created successfully
|
||||
if conn == nil {
|
||||
t.Fatal("NewCodespaceConnection returned nil")
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ type PortForwarder interface {
|
|||
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
|
||||
KeepAlive(reason string)
|
||||
GetKeepAliveReason() string
|
||||
CloseSSHConnection()
|
||||
Close() error
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified codespace.
|
||||
|
|
@ -66,9 +66,6 @@ func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, o
|
|||
return fmt.Errorf("error forwarding port: %w", err)
|
||||
}
|
||||
|
||||
// Close the SSH connection when we're done
|
||||
defer fwd.CloseSSHConnection()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
// Convert the port number to a uint16
|
||||
|
|
@ -151,7 +148,7 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar
|
|||
}
|
||||
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
err = fwd.connection.Connect(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect failed: %v", err)
|
||||
}
|
||||
|
|
@ -159,7 +156,6 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar
|
|||
// Inform the host that we've forwarded the port locally
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
fwd.CloseSSHConnection()
|
||||
return fmt.Errorf("refresh ports failed: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -257,15 +253,12 @@ func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, re
|
|||
done := make(chan error)
|
||||
go func() {
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
err = fwd.connection.Connect(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Close the SSH connection when we're done
|
||||
defer fwd.CloseSSHConnection()
|
||||
|
||||
// Inform the host that we've deleted the port
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -316,8 +309,8 @@ func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
|
|||
}
|
||||
|
||||
// Close closes the port forwarder's tunnel client connection.
|
||||
func (fwd *CodespacesPortForwarder) CloseSSHConnection() {
|
||||
_ = fwd.connection.TunnelClient.Close()
|
||||
func (fwd *CodespacesPortForwarder) Close() error {
|
||||
return fwd.connection.Close()
|
||||
}
|
||||
|
||||
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ import (
|
|||
type PortForwarder struct{}
|
||||
|
||||
// Close implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) CloseSSHConnection() {
|
||||
panic("unimplemented")
|
||||
func (PortForwarder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectToForwardedPort implements portforwarder.PortForwarder.
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, localPort, err := ListenTCP(0, false)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
var (
|
||||
invoker rpc.Invoker
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, localPort, err := codespaces.ListenTCP(0, false)
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
var ports []*tunnels.TunnelPort
|
||||
err = a.RunWithProgress("Fetching ports", func() (err error) {
|
||||
|
|
@ -246,6 +247,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
// TODO: check if port visibility can be updated in parallel instead of sequentially
|
||||
for _, port := range ports {
|
||||
|
|
@ -337,6 +339,7 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: pair.remote,
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
invoker, err := rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
var (
|
||||
invoker rpc.Invoker
|
||||
|
|
@ -238,9 +239,6 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
return fmt.Errorf("failed to forward port: %w", err)
|
||||
}
|
||||
|
||||
// Close the SSH connection when we're done
|
||||
defer fwd.CloseSSHConnection()
|
||||
|
||||
// Connect to the forwarded port
|
||||
err = fwd.ConnectToForwardedPort(ctx, stdio, opts)
|
||||
if err != nil {
|
||||
|
|
@ -584,6 +582,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
|
|||
sshUsers <- result
|
||||
return
|
||||
}
|
||||
defer safeClose(fwd, &err)
|
||||
|
||||
invoker, err := rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue