diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index ba23ef8b4..6fd4be6c4 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "io" "os" "sort" @@ -93,6 +94,12 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use return codespace, token, nil } +func safeClose(closer io.Closer, err *error) { + if closeErr := closer.Close(); *err == nil { + *err = closeErr + } +} + // hasTTY indicates whether the process connected to a terminal. // It is not portable to assume stdin/stdout are fds 0 and 1. var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index dc4cc9c98..536471a2d 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -40,7 +40,7 @@ func init() { rootCmd.AddCommand(newLogsCmd()) } -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) error { +func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -66,6 +66,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } + defer safeClose(session, &err) if err := <-authkeys; err != nil { return err diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 7bc53c441..f48dd1e6d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -51,7 +51,7 @@ func init() { rootCmd.AddCommand(newPortsCmd()) } -func ports(codespaceName string, asJSON bool) error { +func ports(codespaceName string, asJSON bool) (err error) { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, asJSON) @@ -76,6 +76,7 @@ func ports(codespaceName string, asJSON bool) error { if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) log.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) @@ -198,7 +199,7 @@ func newPortsPrivateCmd() *cobra.Command { } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { +func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) @@ -219,6 +220,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) port, err := strconv.Atoi(sourcePort) if err != nil { @@ -260,7 +262,7 @@ func newPortsForwardCmd() *cobra.Command { } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { +func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) @@ -286,6 +288,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 4cab4a41b..bac41472a 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -36,7 +36,7 @@ func init() { rootCmd.AddCommand(newSSHCmd()) } -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -63,6 +63,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) if err := <-authkeys; err != nil { return err diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 6235ca3a0..2933c9d8d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,6 +23,8 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +// ConnectToLiveshare creates a Live Share client and joins the Live Share session. +// It will start the Codespace if it is not already running, it will time out after 60 seconds if fails to start. func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 408f11941..31105d576 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { +func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err) @@ -46,6 +46,11 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } + defer func() { + if closeErr := session.Close(); err == nil { + err = closeErr + } + }() // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port