From 581b6652e258bc7487ac85ab49919a3a3c365cf9 Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 31 Oct 2023 12:19:14 -0700 Subject: [PATCH 1/5] Resolve race condition in codespaces connection --- internal/codespaces/connection/connection.go | 48 +++++++++++++++++-- .../codespaces/connection/connection_test.go | 6 +++ .../portforwarder/port_forwarder.go | 17 ++----- .../codespaces/rpc/test/port_forwarder.go | 4 +- internal/codespaces/states.go | 1 + pkg/cmd/codespace/jupyter.go | 1 + pkg/cmd/codespace/logs.go | 1 + pkg/cmd/codespace/ports.go | 3 ++ pkg/cmd/codespace/rebuild.go | 1 + pkg/cmd/codespace/ssh.go | 5 +- 10 files changed, 67 insertions(+), 20 deletions(-) diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go index 5eea89a76..09ca43b48 100644 --- a/internal/codespaces/connection/connection.go +++ b/internal/codespaces/connection/connection.go @@ -16,10 +16,15 @@ const ( clientName = "gh" ) +type TunnelClient struct { + *tunnels.Client + connected bool +} + type CodespaceConnection struct { tunnelProperties api.TunnelProperties TunnelManager *tunnels.Manager - TunnelClient *tunnels.Client + TunnelClient *TunnelClient Options *tunnels.TunnelRequestOptions Tunnel *tunnels.Tunnel AllowedPortPrivacySettings []string @@ -74,6 +79,38 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC }, nil } +// Connect connects the client to the tunnel. +func (c *CodespaceConnection) Connect(ctx context.Context) error { + // If already connected, return + if c.TunnelClient.connected { + return nil + } + + // Connect to the tunnel + if err := c.TunnelClient.Client.Connect(ctx, ""); err != nil { + return fmt.Errorf("error connecting to tunnel: %w", err) + } + + // Set the connected flag so we know we're connected + c.TunnelClient.connected = true + + return nil +} + +// Close closes the underlying tunnel client SSH connection. +func (c *CodespaceConnection) Close() error { + // Don't close if we're not connected + if c.TunnelClient != nil && c.TunnelClient.connected { + if err := c.TunnelClient.Close(); err != nil { + return fmt.Errorf("failed to close tunnel client connection: %w", err) + } + + c.TunnelClient.connected = false + } + + return nil +} + // getTunnelManager creates a tunnel manager for the given codespace. // The tunnel manager is used to get the tunnel hosted in the codespace that we // want to connect to and perform operations on ports (add, remove, list, etc.). @@ -96,7 +133,7 @@ func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Cl // getTunnelClient creates a tunnel client for the given tunnel. // The tunnel client is used to connect to the the tunnel and allows // for ports to be forwarded locally. -func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) { +func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *TunnelClient, err error) { // Get the tunnel that we want to connect to codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options) if err != nil { @@ -107,10 +144,15 @@ func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel codespaceTunnel.AccessTokens = tunnel.AccessTokens // We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports - tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false) + client, err := tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false) if err != nil { return nil, fmt.Errorf("error creating tunnel client: %w", err) } + tunnelClient = &TunnelClient{ + Client: client, + connected: false, + } + return tunnelClient, nil } diff --git a/internal/codespaces/connection/connection_test.go b/internal/codespaces/connection/connection_test.go index e7ebd2788..a444b8cc6 100644 --- a/internal/codespaces/connection/connection_test.go +++ b/internal/codespaces/connection/connection_test.go @@ -41,6 +41,12 @@ func TestNewCodespaceConnection(t *testing.T) { t.Fatalf("NewCodespaceConnection returned an error: %v", err) } + // Verify closing before connected doesn't throw + err = conn.Close() + if err != nil { + t.Fatalf("Close returned an error: %v", err) + } + // Check that the connection was created successfully if conn == nil { t.Fatal("NewCodespaceConnection returned nil") diff --git a/internal/codespaces/portforwarder/port_forwarder.go b/internal/codespaces/portforwarder/port_forwarder.go index 44838a6be..b62d13715 100644 --- a/internal/codespaces/portforwarder/port_forwarder.go +++ b/internal/codespaces/portforwarder/port_forwarder.go @@ -48,7 +48,7 @@ type PortForwarder interface { UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error KeepAlive(reason string) GetKeepAliveReason() string - CloseSSHConnection() + Close() error } // NewPortForwarder returns a new PortForwarder for the specified codespace. @@ -66,9 +66,6 @@ func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, o return fmt.Errorf("error forwarding port: %w", err) } - // Close the SSH connection when we're done - defer fwd.CloseSSHConnection() - done := make(chan error) go func() { // Convert the port number to a uint16 @@ -151,7 +148,7 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar } // Connect to the tunnel - err = fwd.connection.TunnelClient.Connect(ctx, "") + err = fwd.connection.Connect(ctx) if err != nil { return fmt.Errorf("connect failed: %v", err) } @@ -159,7 +156,6 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar // Inform the host that we've forwarded the port locally err = fwd.connection.TunnelClient.RefreshPorts(ctx) if err != nil { - fwd.CloseSSHConnection() return fmt.Errorf("refresh ports failed: %v", err) } @@ -257,15 +253,12 @@ func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, re done := make(chan error) go func() { // Connect to the tunnel - err = fwd.connection.TunnelClient.Connect(ctx, "") + err = fwd.connection.Connect(ctx) if err != nil { done <- fmt.Errorf("connect failed: %v", err) return } - // Close the SSH connection when we're done - defer fwd.CloseSSHConnection() - // Inform the host that we've deleted the port err = fwd.connection.TunnelClient.RefreshPorts(ctx) if err != nil { @@ -316,8 +309,8 @@ func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string { } // Close closes the port forwarder's tunnel client connection. -func (fwd *CodespacesPortForwarder) CloseSSHConnection() { - _ = fwd.connection.TunnelClient.Close() +func (fwd *CodespacesPortForwarder) Close() error { + return fwd.connection.Close() } // AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value. diff --git a/internal/codespaces/rpc/test/port_forwarder.go b/internal/codespaces/rpc/test/port_forwarder.go index 4993ac0a1..6930988cc 100644 --- a/internal/codespaces/rpc/test/port_forwarder.go +++ b/internal/codespaces/rpc/test/port_forwarder.go @@ -13,8 +13,8 @@ import ( type PortForwarder struct{} // Close implements portforwarder.PortForwarder. -func (PortForwarder) CloseSSHConnection() { - panic("unimplemented") +func (PortForwarder) Close() error { + return nil } // ConnectToForwardedPort implements portforwarder.PortForwarder. diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index afbdf4673..ab673dd2c 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -47,6 +47,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, localPort, err := ListenTCP(0, false) diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 3546837fa..a27a34271 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -48,6 +48,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) var ( invoker rpc.Invoker diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 13d5ce185..37b301217 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -51,6 +51,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, localPort, err := codespaces.ListenTCP(0, false) diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 3903da6cb..0608f0732 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -66,6 +66,7 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) var ports []*tunnels.TunnelPort err = a.RunWithProgress("Fetching ports", func() (err error) { @@ -246,6 +247,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) // TODO: check if port visibility can be updated in parallel instead of sequentially for _, port := range ports { @@ -337,6 +339,7 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) opts := portforwarder.ForwardPortOpts{ Port: pair.remote, diff --git a/pkg/cmd/codespace/rebuild.go b/pkg/cmd/codespace/rebuild.go index 464c58502..93d43bf29 100644 --- a/pkg/cmd/codespace/rebuild.go +++ b/pkg/cmd/codespace/rebuild.go @@ -60,6 +60,7 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil { diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 42571b52e..384b92bb8 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -202,6 +202,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) var ( invoker rpc.Invoker @@ -238,9 +239,6 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e return fmt.Errorf("failed to forward port: %w", err) } - // Close the SSH connection when we're done - defer fwd.CloseSSHConnection() - // Connect to the forwarded port err = fwd.ConnectToForwardedPort(ctx, stdio, opts) if err != nil { @@ -584,6 +582,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro sshUsers <- result return } + defer safeClose(fwd, &err) invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil { From b566ea670c8a87561ae865e630b485bab39507ff Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 31 Oct 2023 13:01:24 -0700 Subject: [PATCH 2/5] Add mutex for connect --- internal/codespaces/connection/connection.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go index 09ca43b48..eaffa1862 100644 --- a/internal/codespaces/connection/connection.go +++ b/internal/codespaces/connection/connection.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "net/url" + "sync" "github.com/cli/cli/v2/internal/codespaces/api" "github.com/microsoft/dev-tunnels/go/tunnels" @@ -19,6 +20,7 @@ const ( type TunnelClient struct { *tunnels.Client connected bool + connectMu sync.Mutex } type CodespaceConnection struct { @@ -81,6 +83,9 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC // Connect connects the client to the tunnel. func (c *CodespaceConnection) Connect(ctx context.Context) error { + // Lock the mutex to prevent connection races + c.TunnelClient.connectMu.Lock() + // If already connected, return if c.TunnelClient.connected { return nil @@ -94,6 +99,9 @@ func (c *CodespaceConnection) Connect(ctx context.Context) error { // Set the connected flag so we know we're connected c.TunnelClient.connected = true + // Unlock the mutex + c.TunnelClient.connectMu.Unlock() + return nil } From 1bcf92438a48872a532033d09d04f55a9ee2ebc6 Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 31 Oct 2023 13:08:34 -0700 Subject: [PATCH 3/5] Defer the mutex unlock --- internal/codespaces/connection/connection.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go index eaffa1862..320817cca 100644 --- a/internal/codespaces/connection/connection.go +++ b/internal/codespaces/connection/connection.go @@ -85,6 +85,7 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC func (c *CodespaceConnection) Connect(ctx context.Context) error { // Lock the mutex to prevent connection races c.TunnelClient.connectMu.Lock() + defer c.TunnelClient.connectMu.Unlock() // If already connected, return if c.TunnelClient.connected { @@ -99,9 +100,6 @@ func (c *CodespaceConnection) Connect(ctx context.Context) error { // Set the connected flag so we know we're connected c.TunnelClient.connected = true - // Unlock the mutex - c.TunnelClient.connectMu.Unlock() - return nil } From d04a9d941f1155a8a3a70ca1bfe80fe6e9a59988 Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 31 Oct 2023 13:11:28 -0700 Subject: [PATCH 4/5] Lock the `Close` func --- internal/codespaces/connection/connection.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go index 320817cca..2528b28ed 100644 --- a/internal/codespaces/connection/connection.go +++ b/internal/codespaces/connection/connection.go @@ -20,7 +20,7 @@ const ( type TunnelClient struct { *tunnels.Client connected bool - connectMu sync.Mutex + mu sync.Mutex } type CodespaceConnection struct { @@ -84,8 +84,8 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC // Connect connects the client to the tunnel. func (c *CodespaceConnection) Connect(ctx context.Context) error { // Lock the mutex to prevent connection races - c.TunnelClient.connectMu.Lock() - defer c.TunnelClient.connectMu.Unlock() + c.TunnelClient.mu.Lock() + defer c.TunnelClient.mu.Unlock() // If already connected, return if c.TunnelClient.connected { @@ -105,6 +105,10 @@ func (c *CodespaceConnection) Connect(ctx context.Context) error { // Close closes the underlying tunnel client SSH connection. func (c *CodespaceConnection) Close() error { + // Lock the mutex to prevent connection races + c.TunnelClient.mu.Lock() + defer c.TunnelClient.mu.Unlock() + // Don't close if we're not connected if c.TunnelClient != nil && c.TunnelClient.connected { if err := c.TunnelClient.Close(); err != nil { From d22c6f33e9c828e89bf5ecd4abf4b5436e7306cd Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 31 Oct 2023 13:12:25 -0700 Subject: [PATCH 5/5] Update comment --- internal/codespaces/connection/connection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go index 2528b28ed..7eda95e50 100644 --- a/internal/codespaces/connection/connection.go +++ b/internal/codespaces/connection/connection.go @@ -83,7 +83,7 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC // Connect connects the client to the tunnel. func (c *CodespaceConnection) Connect(ctx context.Context) error { - // Lock the mutex to prevent connection races + // Lock the mutex to prevent race conditions with the underlying SSH connection c.TunnelClient.mu.Lock() defer c.TunnelClient.mu.Unlock() @@ -105,7 +105,7 @@ func (c *CodespaceConnection) Connect(ctx context.Context) error { // Close closes the underlying tunnel client SSH connection. func (c *CodespaceConnection) Close() error { - // Lock the mutex to prevent connection races + // Lock the mutex to prevent race conditions with the underlying SSH connection c.TunnelClient.mu.Lock() defer c.TunnelClient.mu.Unlock()