diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index f15d2fef6..fc7acef2f 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "io" "os" "sort" @@ -93,6 +94,12 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use return codespace, token, nil } +func safeClose(closer io.Closer, err *error) { + if closeErr := closer.Close(); *err == nil { + *err = closeErr + } +} + // hasTTY indicates whether the process connected to a terminal. // It is not portable to assume stdin/stdout are fds 0 and 1. var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) @@ -120,3 +127,17 @@ func ask(qs []*survey.Question, response interface{}) error { } return err } + +// checkAuthorizedKeys reports an error if the user has not registered any SSH keys; +// see https://github.com/github/ghcs/issues/166#issuecomment-921769703. +// The check is not required for security but it improves the error message. +func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error { + keys, err := client.AuthorizedKeys(ctx, user) + if err != nil { + return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) + } + if len(keys) == 0 { + return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) + } + return nil // success +} diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 1ce008120..70311c884 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -27,6 +27,9 @@ func newDeleteCmd() *cobra.Command { Use: "delete", Short: "Delete a codespace", RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("delete: unexpected positional arguments") + } switch { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported") diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 2b50effd1..514c36966 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -36,7 +36,7 @@ func newLogsCmd() *cobra.Command { return logsCmd } -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) error { +func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -48,6 +48,11 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow return fmt.Errorf("getting user: %w", err) } + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + }() + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) @@ -57,6 +62,11 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } + defer safeClose(session, &err) + + if err := <-authkeys; err != nil { + return err + } // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index ebfd281cd..aeecf0a07 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -47,8 +47,8 @@ func newPortsCmd() *cobra.Command { return portsCmd } -func ports(codespaceName string, asJSON bool) error { - apiClient := api.New(GithubToken) +func ports(codespaceName string, asJSON bool) (err error) { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, asJSON) @@ -72,6 +72,7 @@ func ports(codespaceName string, asJSON bool) error { if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) log.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) @@ -194,7 +195,7 @@ func newPortsPrivateCmd() *cobra.Command { } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { +func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { ctx := context.Background() apiClient := api.New(GithubToken) @@ -215,6 +216,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) port, err := strconv.Atoi(sourcePort) if err != nil { @@ -256,7 +258,7 @@ func newPortsForwardCmd() *cobra.Command { } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { +func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { ctx := context.Background() apiClient := api.New(GithubToken) @@ -282,6 +284,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e4435853b..3a49e6ebc 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -32,7 +32,7 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -45,6 +45,11 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string return fmt.Errorf("error getting user: %w", err) } + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + }() + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) @@ -54,6 +59,11 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer safeClose(session, &err) + + if err := <-authkeys; err != nil { + return err + } log.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) diff --git a/internal/api/api.go b/internal/api/api.go index 1246389e8..2dd4d71b2 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -13,6 +13,7 @@ package api // - github.GetUser(github.Client) // - github.GetRepository(Client) // - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents +// - github.AuthorizedKeys(Client, user) // - codespaces.Create(Client, user, repo, sku, branch, location) // - codespaces.Delete(Client, user, token, name) // - codespaces.Get(Client, token, owner, name) @@ -507,6 +508,31 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } +// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys +// format) registered by the specified GitHub user. +func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + url := fmt.Sprintf("https://github.com/%s.keys", user) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := a.do(ctx, req, "/user.keys") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned %s", resp.Status) + } + return b, nil +} + func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. span, ctx := opentracing.StartSpanFromContext(ctx, spanName) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 6235ca3a0..2933c9d8d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,6 +23,8 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +// ConnectToLiveshare creates a Live Share client and joins the Live Share session. +// It will start the Codespace if it is not already running, it will time out after 60 seconds if fails to start. func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 408f11941..31105d576 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { +func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err) @@ -46,6 +46,11 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } + defer func() { + if closeErr := session.Close(); err == nil { + err = closeErr + } + }() // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port