diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index d21f9ff6c..575b95853 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -15,8 +15,10 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" + "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" "github.com/cli/cli/v2/pkg/iostreams" + "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" "golang.org/x/term" ) @@ -59,6 +61,36 @@ func (a *App) StopProgressIndicator() { a.io.StopProgressIndicator() } +func startSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (*liveshare.Session, func(*error), error) { + // While connecting, ensure in the background that the user has keys installed. + // That lets us report a more useful error message if they don't. + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, a.apiClient) + }() + + liveshareLogger := noopLogger() + if debug { + debugLogger, err := newFileLogger(debugFile) + if err != nil { + return nil, nil, fmt.Errorf("error creating debug logger: %w", err) + } + defer safeClose(debugLogger, &err) + + liveshareLogger = debugLogger.Logger + a.errLogger.Printf("Debug file located at: %s", debugLogger.Name()) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace) + if err != nil { + if authErr := <-authkeys; authErr != nil { + return nil, nil, authErr + } + return nil, nil, fmt.Errorf("error connecting to codespace: %w", err) + } + return session, func(e *error) { safeClose(session, e) }, nil +} + //go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient type apiClient interface { GetUser(ctx context.Context) (*api.User, error) diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 484ab7577..88f0f0097 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -6,7 +6,6 @@ import ( "net" "strings" - "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) @@ -15,7 +14,6 @@ type jupyterOptions struct { codespace string debug bool debugFile string - stdio bool } func newJupyterCmd(app *App) *cobra.Command { @@ -34,8 +32,6 @@ func newJupyterCmd(app *App) *cobra.Command { } func (a *App) Jupyter(ctx context.Context, opts jupyterOptions) error { - // TODO: Share liveshare setup code with ssh.go and logs.go - // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -45,33 +41,11 @@ func (a *App) Jupyter(ctx context.Context, opts jupyterOptions) error { return fmt.Errorf("get or choose codespace: %w", err) } - // While connecting, ensure in the background that the user has keys installed. - // That lets us report a more useful error message if they don't. - authkeys := make(chan error, 1) - go func() { - authkeys <- checkAuthorizedKeys(ctx, a.apiClient) - }() - - liveshareLogger := noopLogger() - if opts.debug { - debugLogger, err := newFileLogger(opts.debugFile) - if err != nil { - return fmt.Errorf("error creating debug logger: %w", err) - } - defer safeClose(debugLogger, &err) - - liveshareLogger = debugLogger.Logger - a.errLogger.Printf("Debug file located at: %s", debugLogger.Name()) - } - - session, err := codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace) + session, closeSession, err := startSession(ctx, codespace, a, opts.debug, opts.debugFile) if err != nil { - if authErr := <-authkeys; authErr != nil { - return authErr - } - return fmt.Errorf("error connecting to codespace: %w", err) + return err } - defer safeClose(session, &err) + defer closeSession(&err) a.StartProgressIndicatorWithLabel("Fetching Jupyter Details") jupyterServerPort, jupyterServerUrl, err := session.StartJupyterServer(ctx) diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index d0a0c233b..b1173debe 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -41,20 +41,11 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err return fmt.Errorf("get or choose codespace: %w", err) } - authkeys := make(chan error, 1) - go func() { - authkeys <- checkAuthorizedKeys(ctx, a.apiClient) - }() - - session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace) + session, closeSession, err := startSession(ctx, codespace, a, false, "") if err != nil { - return fmt.Errorf("connecting to codespace: %w", err) - } - defer safeClose(session, &err) - - if err := <-authkeys; err != nil { return err } + defer closeSession(&err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 094833e30..ad60eeecd 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -57,11 +57,11 @@ func (a *App) ListPorts(ctx context.Context, codespaceName string, exporter cmdu devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace) + session, closeSession, err := startSession(ctx, codespace, a, false, "") if err != nil { - return fmt.Errorf("error connecting to codespace: %w", err) + return err } - defer safeClose(session, &err) + defer closeSession(&err) a.StartProgressIndicatorWithLabel("Fetching ports") ports, err := session.GetSharedServers(ctx) diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 726f2152f..b70f33eb1 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -116,38 +116,16 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e ctx, cancel := context.WithCancel(ctx) defer cancel() - // While connecting, ensure in the background that the user has keys installed. - // That lets us report a more useful error message if they don't. - authkeys := make(chan error, 1) - go func() { - authkeys <- checkAuthorizedKeys(ctx, a.apiClient) - }() - codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - liveshareLogger := noopLogger() - if opts.debug { - debugLogger, err := newFileLogger(opts.debugFile) - if err != nil { - return fmt.Errorf("error creating debug logger: %w", err) - } - defer safeClose(debugLogger, &err) - - liveshareLogger = debugLogger.Logger - a.errLogger.Printf("Debug file located at: %s", debugLogger.Name()) - } - - session, err := codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace) + session, closeSession, err := startSession(ctx, codespace, a, opts.debug, opts.debugFile) if err != nil { - if authErr := <-authkeys; authErr != nil { - return authErr - } - return fmt.Errorf("error connecting to codespace: %w", err) + return err } - defer safeClose(session, &err) + defer closeSession(&err) a.StartProgressIndicatorWithLabel("Fetching SSH Details") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)