Resolve race condition in codespaces connection

This commit is contained in:
David Gardiner 2023-10-31 12:19:14 -07:00
parent ede1705bf2
commit 581b6652e2
10 changed files with 67 additions and 20 deletions

View file

@ -16,10 +16,15 @@ const (
clientName = "gh"
)
type TunnelClient struct {
*tunnels.Client
connected bool
}
type CodespaceConnection struct {
tunnelProperties api.TunnelProperties
TunnelManager *tunnels.Manager
TunnelClient *tunnels.Client
TunnelClient *TunnelClient
Options *tunnels.TunnelRequestOptions
Tunnel *tunnels.Tunnel
AllowedPortPrivacySettings []string
@ -74,6 +79,38 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC
}, nil
}
// Connect connects the client to the tunnel.
func (c *CodespaceConnection) Connect(ctx context.Context) error {
// If already connected, return
if c.TunnelClient.connected {
return nil
}
// Connect to the tunnel
if err := c.TunnelClient.Client.Connect(ctx, ""); err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
// Set the connected flag so we know we're connected
c.TunnelClient.connected = true
return nil
}
// Close closes the underlying tunnel client SSH connection.
func (c *CodespaceConnection) Close() error {
// Don't close if we're not connected
if c.TunnelClient != nil && c.TunnelClient.connected {
if err := c.TunnelClient.Close(); err != nil {
return fmt.Errorf("failed to close tunnel client connection: %w", err)
}
c.TunnelClient.connected = false
}
return nil
}
// getTunnelManager creates a tunnel manager for the given codespace.
// The tunnel manager is used to get the tunnel hosted in the codespace that we
// want to connect to and perform operations on ports (add, remove, list, etc.).
@ -96,7 +133,7 @@ func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Cl
// getTunnelClient creates a tunnel client for the given tunnel.
// The tunnel client is used to connect to the the tunnel and allows
// for ports to be forwarded locally.
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *TunnelClient, err error) {
// Get the tunnel that we want to connect to
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
if err != nil {
@ -107,10 +144,15 @@ func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel
codespaceTunnel.AccessTokens = tunnel.AccessTokens
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
client, err := tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
if err != nil {
return nil, fmt.Errorf("error creating tunnel client: %w", err)
}
tunnelClient = &TunnelClient{
Client: client,
connected: false,
}
return tunnelClient, nil
}

View file

@ -41,6 +41,12 @@ func TestNewCodespaceConnection(t *testing.T) {
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
}
// Verify closing before connected doesn't throw
err = conn.Close()
if err != nil {
t.Fatalf("Close returned an error: %v", err)
}
// Check that the connection was created successfully
if conn == nil {
t.Fatal("NewCodespaceConnection returned nil")

View file

@ -48,7 +48,7 @@ type PortForwarder interface {
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
KeepAlive(reason string)
GetKeepAliveReason() string
CloseSSHConnection()
Close() error
}
// NewPortForwarder returns a new PortForwarder for the specified codespace.
@ -66,9 +66,6 @@ func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, o
return fmt.Errorf("error forwarding port: %w", err)
}
// Close the SSH connection when we're done
defer fwd.CloseSSHConnection()
done := make(chan error)
go func() {
// Convert the port number to a uint16
@ -151,7 +148,7 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar
}
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
err = fwd.connection.Connect(ctx)
if err != nil {
return fmt.Errorf("connect failed: %v", err)
}
@ -159,7 +156,6 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar
// Inform the host that we've forwarded the port locally
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
fwd.CloseSSHConnection()
return fmt.Errorf("refresh ports failed: %v", err)
}
@ -257,15 +253,12 @@ func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, re
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
err = fwd.connection.Connect(ctx)
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Close the SSH connection when we're done
defer fwd.CloseSSHConnection()
// Inform the host that we've deleted the port
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
@ -316,8 +309,8 @@ func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
}
// Close closes the port forwarder's tunnel client connection.
func (fwd *CodespacesPortForwarder) CloseSSHConnection() {
_ = fwd.connection.TunnelClient.Close()
func (fwd *CodespacesPortForwarder) Close() error {
return fwd.connection.Close()
}
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.

View file

@ -13,8 +13,8 @@ import (
type PortForwarder struct{}
// Close implements portforwarder.PortForwarder.
func (PortForwarder) CloseSSHConnection() {
panic("unimplemented")
func (PortForwarder) Close() error {
return nil
}
// ConnectToForwardedPort implements portforwarder.PortForwarder.

View file

@ -47,6 +47,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, localPort, err := ListenTCP(0, false)

View file

@ -48,6 +48,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
var (
invoker rpc.Invoker

View file

@ -51,6 +51,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, localPort, err := codespaces.ListenTCP(0, false)

View file

@ -66,6 +66,7 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
var ports []*tunnels.TunnelPort
err = a.RunWithProgress("Fetching ports", func() (err error) {
@ -246,6 +247,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
// TODO: check if port visibility can be updated in parallel instead of sequentially
for _, port := range ports {
@ -337,6 +339,7 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
opts := portforwarder.ForwardPortOpts{
Port: pair.remote,

View file

@ -60,6 +60,7 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
invoker, err := rpc.CreateInvoker(ctx, fwd)
if err != nil {

View file

@ -202,6 +202,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(fwd, &err)
var (
invoker rpc.Invoker
@ -238,9 +239,6 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
return fmt.Errorf("failed to forward port: %w", err)
}
// Close the SSH connection when we're done
defer fwd.CloseSSHConnection()
// Connect to the forwarded port
err = fwd.ConnectToForwardedPort(ctx, stdio, opts)
if err != nil {
@ -584,6 +582,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
sshUsers <- result
return
}
defer safeClose(fwd, &err)
invoker, err := rpc.CreateInvoker(ctx, fwd)
if err != nil {