diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index d905c75ac..9f09438d5 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -1,10 +1,9 @@ -package main +package ghcs import ( "context" "fmt" "net/url" - "os" "github.com/github/ghcs/internal/api" "github.com/skratchdot/open-golang/open" @@ -32,12 +31,8 @@ func newCodeCmd() *cobra.Command { return codeCmd } -func init() { - rootCmd.AddCommand(newCodeCmd()) -} - func code(codespaceName string, useInsiders bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e71e3dfe4..fc7acef2f 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -1,4 +1,4 @@ -package main +package ghcs // This file defines functions common to the entire ghcs command set. @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "io" "os" "sort" @@ -93,6 +94,12 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use return codespace, token, nil } +func safeClose(closer io.Closer, err *error) { + if closeErr := closer.Close(); *err == nil { + *err = closeErr + } +} + // hasTTY indicates whether the process connected to a terminal. // It is not portable to assume stdin/stdout are fds 0 and 1. var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) @@ -120,3 +127,17 @@ func ask(qs []*survey.Question, response interface{}) error { } return err } + +// checkAuthorizedKeys reports an error if the user has not registered any SSH keys; +// see https://github.com/github/ghcs/issues/166#issuecomment-921769703. +// The check is not required for security but it improves the error message. +func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error { + keys, err := client.AuthorizedKeys(ctx, user) + if err != nil { + return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) + } + if len(keys) == 0 { + return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) + } + return nil // success +} diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 93016bbf8..fd54a170c 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -43,13 +43,9 @@ func newCreateCmd() *cobra.Command { return createCmd } -func init() { - rootCmd.AddCommand(newCreateCmd()) -} - func create(opts *createOptions) error { ctx := context.Background() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) log := output.NewLogger(os.Stdout, os.Stderr, false) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index eb00e567f..70311c884 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -27,6 +27,9 @@ func newDeleteCmd() *cobra.Command { Use: "delete", Short: "Delete a codespace", RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("delete: unexpected positional arguments") + } switch { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported") @@ -48,12 +51,8 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -func init() { - rootCmd.AddCommand(newDeleteCmd()) -} - func delete_(log *output.Logger, codespaceName string, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) @@ -85,7 +84,7 @@ func delete_(log *output.Logger, codespaceName string, force bool) error { } func deleteAll(log *output.Logger, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) @@ -124,7 +123,7 @@ func deleteAll(log *output.Logger, force bool) error { } func deleteByRepo(log *output.Logger, repo string, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index fb8d83c78..7ee156012 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -31,12 +31,8 @@ func newListCmd() *cobra.Command { return listCmd } -func init() { - rootCmd.AddCommand(newListCmd()) -} - func list(opts *listOptions) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 19528061a..514c36966 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -36,22 +36,23 @@ func newLogsCmd() *cobra.Command { return logsCmd } -func init() { - rootCmd.AddCommand(newLogsCmd()) -} - -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) error { +func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("getting user: %w", err) } + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + }() + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) @@ -61,6 +62,11 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } + defer safeClose(session, &err) + + if err := <-authkeys; err != nil { + return err + } // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go new file mode 100644 index 000000000..01dde1270 --- /dev/null +++ b/cmd/ghcs/main/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "errors" + "fmt" + "io" + "os" + + "github.com/github/ghcs/cmd/ghcs" +) + +func main() { + rootCmd := ghcs.NewRootCmd() + if err := rootCmd.Execute(); err != nil { + explainError(os.Stderr, err) + os.Exit(1) + } +} + +func explainError(w io.Writer, err error) { + if errors.Is(err, ghcs.ErrTokenMissing) { + fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") + fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") + return + } +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 7bc53c441..aeecf0a07 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "bytes" @@ -47,11 +47,7 @@ func newPortsCmd() *cobra.Command { return portsCmd } -func init() { - rootCmd.AddCommand(newPortsCmd()) -} - -func ports(codespaceName string, asJSON bool) error { +func ports(codespaceName string, asJSON bool) (err error) { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, asJSON) @@ -76,6 +72,7 @@ func ports(codespaceName string, asJSON bool) error { if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) log.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) @@ -198,9 +195,9 @@ func newPortsPrivateCmd() *cobra.Command { } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { +func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { ctx := context.Background() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) user, err := apiClient.GetUser(ctx) if err != nil { @@ -219,6 +216,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) port, err := strconv.Atoi(sourcePort) if err != nil { @@ -260,9 +258,9 @@ func newPortsForwardCmd() *cobra.Command { } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { +func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { ctx := context.Background() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) portPairs, err := getPortPairs(ports) if err != nil { @@ -286,6 +284,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. diff --git a/cmd/ghcs/main.go b/cmd/ghcs/root.go similarity index 77% rename from cmd/ghcs/main.go rename to cmd/ghcs/root.go index 7903dad2a..6db4144a8 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/root.go @@ -1,9 +1,8 @@ -package main +package ghcs import ( "errors" "fmt" - "io" "log" "os" "strconv" @@ -14,18 +13,12 @@ import ( "github.com/spf13/cobra" ) -func main() { - if err := rootCmd.Execute(); err != nil { - explainError(os.Stderr, err) - os.Exit(1) - } -} - var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. -var rootCmd = newRootCmd() +// GithubToken is a temporary stopgap to make the token configurable by apps that import this package +var GithubToken = os.Getenv("GITHUB_TOKEN") -func newRootCmd() *cobra.Command { +func NewRootCmd() *cobra.Command { var lightstep string root := &cobra.Command{ @@ -40,7 +33,7 @@ token to access the GitHub API with.`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if os.Getenv("GITHUB_TOKEN") == "" { - return errTokenMissing + return ErrTokenMissing } return initLightstep(lightstep) }, @@ -48,18 +41,18 @@ token to access the GitHub API with.`, root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") + root.AddCommand(newCodeCmd()) + root.AddCommand(newCreateCmd()) + root.AddCommand(newDeleteCmd()) + root.AddCommand(newListCmd()) + root.AddCommand(newLogsCmd()) + root.AddCommand(newPortsCmd()) + root.AddCommand(newSSHCmd()) + return root } -var errTokenMissing = errors.New("GITHUB_TOKEN is missing") - -func explainError(w io.Writer, err error) { - if errors.Is(err, errTokenMissing) { - fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") - fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") - return - } -} +var ErrTokenMissing = errors.New("GITHUB_TOKEN is missing") // initLightstep parses the --lightstep=service:token@host:port flag and // enables tracing if non-empty. diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 4ece84d91..3a49e6ebc 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -32,16 +32,12 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func init() { - rootCmd.AddCommand(newSSHCmd()) -} - -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) @@ -49,6 +45,11 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string return fmt.Errorf("error getting user: %w", err) } + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + }() + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) @@ -58,6 +59,11 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) + + if err := <-authkeys; err != nil { + return err + } log.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) diff --git a/internal/api/api.go b/internal/api/api.go index df9fd10c7..394efc6af 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -13,6 +13,7 @@ package api // - github.GetUser(github.Client) // - github.GetRepository(Client) // - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents +// - github.AuthorizedKeys(Client, user) // - codespaces.Create(Client, user, repo, sku, branch, location) // - codespaces.Delete(Client, user, token, name) // - codespaces.Get(Client, token, owner, name) @@ -512,6 +513,31 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } +// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys +// format) registered by the specified GitHub user. +func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + url := fmt.Sprintf("https://github.com/%s.keys", user) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := a.do(ctx, req, "/user.keys") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned %s", resp.Status) + } + return b, nil +} + func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. span, ctx := opentracing.StartSpanFromContext(ctx, spanName) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 6235ca3a0..2933c9d8d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,6 +23,8 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +// ConnectToLiveshare creates a Live Share client and joins the Live Share session. +// It will start the Codespace if it is not already running, it will time out after 60 seconds if fails to start. func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 408f11941..31105d576 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { +func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err) @@ -46,6 +46,11 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } + defer func() { + if closeErr := session.Close(); err == nil { + err = closeErr + } + }() // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port