diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go index 5eea89a76..09ca43b48 100644 --- a/internal/codespaces/connection/connection.go +++ b/internal/codespaces/connection/connection.go @@ -16,10 +16,15 @@ const ( clientName = "gh" ) +type TunnelClient struct { + *tunnels.Client + connected bool +} + type CodespaceConnection struct { tunnelProperties api.TunnelProperties TunnelManager *tunnels.Manager - TunnelClient *tunnels.Client + TunnelClient *TunnelClient Options *tunnels.TunnelRequestOptions Tunnel *tunnels.Tunnel AllowedPortPrivacySettings []string @@ -74,6 +79,38 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC }, nil } +// Connect connects the client to the tunnel. +func (c *CodespaceConnection) Connect(ctx context.Context) error { + // If already connected, return + if c.TunnelClient.connected { + return nil + } + + // Connect to the tunnel + if err := c.TunnelClient.Client.Connect(ctx, ""); err != nil { + return fmt.Errorf("error connecting to tunnel: %w", err) + } + + // Set the connected flag so we know we're connected + c.TunnelClient.connected = true + + return nil +} + +// Close closes the underlying tunnel client SSH connection. +func (c *CodespaceConnection) Close() error { + // Don't close if we're not connected + if c.TunnelClient != nil && c.TunnelClient.connected { + if err := c.TunnelClient.Close(); err != nil { + return fmt.Errorf("failed to close tunnel client connection: %w", err) + } + + c.TunnelClient.connected = false + } + + return nil +} + // getTunnelManager creates a tunnel manager for the given codespace. // The tunnel manager is used to get the tunnel hosted in the codespace that we // want to connect to and perform operations on ports (add, remove, list, etc.). @@ -96,7 +133,7 @@ func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Cl // getTunnelClient creates a tunnel client for the given tunnel. // The tunnel client is used to connect to the the tunnel and allows // for ports to be forwarded locally. -func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) { +func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *TunnelClient, err error) { // Get the tunnel that we want to connect to codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options) if err != nil { @@ -107,10 +144,15 @@ func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel codespaceTunnel.AccessTokens = tunnel.AccessTokens // We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports - tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false) + client, err := tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false) if err != nil { return nil, fmt.Errorf("error creating tunnel client: %w", err) } + tunnelClient = &TunnelClient{ + Client: client, + connected: false, + } + return tunnelClient, nil } diff --git a/internal/codespaces/connection/connection_test.go b/internal/codespaces/connection/connection_test.go index e7ebd2788..a444b8cc6 100644 --- a/internal/codespaces/connection/connection_test.go +++ b/internal/codespaces/connection/connection_test.go @@ -41,6 +41,12 @@ func TestNewCodespaceConnection(t *testing.T) { t.Fatalf("NewCodespaceConnection returned an error: %v", err) } + // Verify closing before connected doesn't throw + err = conn.Close() + if err != nil { + t.Fatalf("Close returned an error: %v", err) + } + // Check that the connection was created successfully if conn == nil { t.Fatal("NewCodespaceConnection returned nil") diff --git a/internal/codespaces/portforwarder/port_forwarder.go b/internal/codespaces/portforwarder/port_forwarder.go index 44838a6be..b62d13715 100644 --- a/internal/codespaces/portforwarder/port_forwarder.go +++ b/internal/codespaces/portforwarder/port_forwarder.go @@ -48,7 +48,7 @@ type PortForwarder interface { UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error KeepAlive(reason string) GetKeepAliveReason() string - CloseSSHConnection() + Close() error } // NewPortForwarder returns a new PortForwarder for the specified codespace. @@ -66,9 +66,6 @@ func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, o return fmt.Errorf("error forwarding port: %w", err) } - // Close the SSH connection when we're done - defer fwd.CloseSSHConnection() - done := make(chan error) go func() { // Convert the port number to a uint16 @@ -151,7 +148,7 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar } // Connect to the tunnel - err = fwd.connection.TunnelClient.Connect(ctx, "") + err = fwd.connection.Connect(ctx) if err != nil { return fmt.Errorf("connect failed: %v", err) } @@ -159,7 +156,6 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar // Inform the host that we've forwarded the port locally err = fwd.connection.TunnelClient.RefreshPorts(ctx) if err != nil { - fwd.CloseSSHConnection() return fmt.Errorf("refresh ports failed: %v", err) } @@ -257,15 +253,12 @@ func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, re done := make(chan error) go func() { // Connect to the tunnel - err = fwd.connection.TunnelClient.Connect(ctx, "") + err = fwd.connection.Connect(ctx) if err != nil { done <- fmt.Errorf("connect failed: %v", err) return } - // Close the SSH connection when we're done - defer fwd.CloseSSHConnection() - // Inform the host that we've deleted the port err = fwd.connection.TunnelClient.RefreshPorts(ctx) if err != nil { @@ -316,8 +309,8 @@ func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string { } // Close closes the port forwarder's tunnel client connection. -func (fwd *CodespacesPortForwarder) CloseSSHConnection() { - _ = fwd.connection.TunnelClient.Close() +func (fwd *CodespacesPortForwarder) Close() error { + return fwd.connection.Close() } // AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value. diff --git a/internal/codespaces/rpc/test/port_forwarder.go b/internal/codespaces/rpc/test/port_forwarder.go index 4993ac0a1..6930988cc 100644 --- a/internal/codespaces/rpc/test/port_forwarder.go +++ b/internal/codespaces/rpc/test/port_forwarder.go @@ -13,8 +13,8 @@ import ( type PortForwarder struct{} // Close implements portforwarder.PortForwarder. -func (PortForwarder) CloseSSHConnection() { - panic("unimplemented") +func (PortForwarder) Close() error { + return nil } // ConnectToForwardedPort implements portforwarder.PortForwarder. diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index afbdf4673..ab673dd2c 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -47,6 +47,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, localPort, err := ListenTCP(0, false) diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 3546837fa..a27a34271 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -48,6 +48,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) var ( invoker rpc.Invoker diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 13d5ce185..37b301217 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -51,6 +51,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, localPort, err := codespaces.ListenTCP(0, false) diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 3903da6cb..0608f0732 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -66,6 +66,7 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) var ports []*tunnels.TunnelPort err = a.RunWithProgress("Fetching ports", func() (err error) { @@ -246,6 +247,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) // TODO: check if port visibility can be updated in parallel instead of sequentially for _, port := range ports { @@ -337,6 +339,7 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) opts := portforwarder.ForwardPortOpts{ Port: pair.remote, diff --git a/pkg/cmd/codespace/rebuild.go b/pkg/cmd/codespace/rebuild.go index 464c58502..93d43bf29 100644 --- a/pkg/cmd/codespace/rebuild.go +++ b/pkg/cmd/codespace/rebuild.go @@ -60,6 +60,7 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil { diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 42571b52e..384b92bb8 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -202,6 +202,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } + defer safeClose(fwd, &err) var ( invoker rpc.Invoker @@ -238,9 +239,6 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e return fmt.Errorf("failed to forward port: %w", err) } - // Close the SSH connection when we're done - defer fwd.CloseSSHConnection() - // Connect to the forwarded port err = fwd.ConnectToForwardedPort(ctx, stdio, opts) if err != nil { @@ -584,6 +582,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro sshUsers <- result return } + defer safeClose(fwd, &err) invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil {