From ca0f89d3bc1bbf2292ec4e0e2b3fbf97e1047fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 16:03:44 +0200 Subject: [PATCH 1/3] Introduce an App struct that executes core business logic The Cobra commands are now a light wrapper around the App struct. Co-authored-by: Jose Garcia --- cmd/ghcs/code.go | 15 +- cmd/ghcs/common.go | 35 ++- cmd/ghcs/create.go | 36 ++- cmd/ghcs/delete.go | 36 +-- cmd/ghcs/delete_test.go | 8 +- cmd/ghcs/list.go | 13 +- cmd/ghcs/logs.go | 23 +- cmd/ghcs/main/main.go | 27 ++- cmd/ghcs/mock_api.go | 383 +++++++++++++++++++++++++++++- cmd/ghcs/ports.go | 72 +++--- cmd/ghcs/root.go | 26 +- cmd/ghcs/ssh.go | 27 +-- internal/api/api.go | 21 +- internal/codespaces/codespaces.go | 8 +- internal/codespaces/states.go | 2 +- 15 files changed, 557 insertions(+), 175 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 08d2cff1a..cfcd989e2 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -5,12 +5,11 @@ import ( "fmt" "net/url" - "github.com/github/ghcs/internal/api" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) -func newCodeCmd() *cobra.Command { +func newCodeCmd(app *App) *cobra.Command { var ( codespace string useInsiders bool @@ -21,7 +20,7 @@ func newCodeCmd() *cobra.Command { Short: "Open a codespace in VS Code", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return code(codespace, useInsiders) + return app.VSCode(cmd.Context(), codespace, useInsiders) }, } @@ -31,17 +30,15 @@ func newCodeCmd() *cobra.Command { return codeCmd } -func code(codespaceName string, useInsiders bool) error { - apiClient := api.New(GithubToken) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +// VSCode opens a codespace in the local VS VSCode application. +func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } if codespaceName == "" { - codespace, err := chooseCodespace(ctx, apiClient, user) + codespace, err := chooseCodespace(ctx, a.apiClient, user) if err != nil { if err == errNoCodespaces { return err diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 371ca30b8..e60fa7c96 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -12,14 +12,43 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/term" ) +type App struct { + apiClient apiClient + logger *output.Logger +} + +func NewApp(logger *output.Logger, apiClient apiClient) *App { + return &App{ + apiClient: apiClient, + logger: logger, + } +} + +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient +type apiClient interface { + GetUser(ctx context.Context) (*api.User, error) + GetCodespaceToken(ctx context.Context, user, name string) (string, error) + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) + DeleteCodespace(ctx context.Context, user, name string) error + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error + CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + GetRepository(ctx context.Context, nwo string) (*api.Repository, error) + AuthorizedKeys(ctx context.Context, user string) ([]byte, error) + GetCodespaceRegionLocation(ctx context.Context) (string, error) + GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error) + GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) +} + var errNoCodespaces = errors.New("you have no codespaces") -func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { +func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user.Login) if err != nil { return nil, fmt.Errorf("error getting codespaces: %w", err) @@ -68,7 +97,7 @@ func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) ( // getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. // It then fetches the codespace token and the codespace record. -func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { +func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = chooseCodespace(ctx, apiClient, user) if err != nil { @@ -135,7 +164,7 @@ func ask(qs []*survey.Question, response interface{}) error { // checkAuthorizedKeys reports an error if the user has not registered any SSH keys; // see https://github.com/github/ghcs/issues/166#issuecomment-921769703. // The check is not required for security but it improves the error message. -func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error { +func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error { keys, err := client.AuthorizedKeys(ctx, user) if err != nil { return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 345489f6b..c92a6edff 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -22,7 +22,7 @@ type createOptions struct { showStatus bool } -func newCreateCmd() *cobra.Command { +func newCreateCmd(app *App) *cobra.Command { opts := &createOptions{} createCmd := &cobra.Command{ @@ -30,7 +30,7 @@ func newCreateCmd() *cobra.Command { Short: "Create a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return create(opts) + return app.Create(cmd.Context(), opts) }, } @@ -42,12 +42,10 @@ func newCreateCmd() *cobra.Command { return createCmd } -func create(opts *createOptions) error { - ctx := context.Background() - apiClient := api.New(GithubToken) - locationCh := getLocation(ctx, apiClient) - userCh := getUser(ctx, apiClient) - log := output.NewLogger(os.Stdout, os.Stderr, false) +// Create creates a new Codespace +func (a *App) Create(ctx context.Context, opts *createOptions) error { + locationCh := getLocation(ctx, a.apiClient) + userCh := getUser(ctx, a.apiClient) repo, err := getRepoName(opts.repo) if err != nil { @@ -58,7 +56,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting branch name: %w", err) } - repository, err := apiClient.GetRepository(ctx, repo) + repository, err := a.apiClient.GetRepository(ctx, repo) if err != nil { return fmt.Errorf("error getting repository: %w", err) } @@ -73,7 +71,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting codespace user: %w", userResult.Err) } - machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient) + machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient) if err != nil { return fmt.Errorf("error getting machine type: %w", err) } @@ -81,26 +79,26 @@ func create(opts *createOptions) error { return errors.New("there are no available machine types for this repository") } - log.Print("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, log, &api.CreateCodespaceParams{ + a.logger.Print("Creating your codespace...") + codespace, err := a.apiClient.CreateCodespace(ctx, a.logger, &api.CreateCodespaceParams{ User: userResult.User.Login, RepositoryID: repository.ID, Branch: branch, Machine: machine, Location: locationResult.Location, }) - log.Print("\n") + a.logger.Print("\n") if err != nil { return fmt.Errorf("error creating codespace: %w", err) } if opts.showStatus { - if err := showStatus(ctx, log, apiClient, userResult.User, codespace); err != nil { + if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil { return fmt.Errorf("show status: %w", err) } } - log.Printf("Codespace created: ") + a.logger.Printf("Codespace created: ") fmt.Fprintln(os.Stdout, codespace.Name) @@ -110,7 +108,7 @@ func create(opts *createOptions) error { // showStatus polls the codespace for a list of post create states and their status. It will keep polling // until all states have finished. Once all states have finished, we poll once more to check if any new // states have been introduced and stop polling otherwise. -func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { +func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error { var lastState codespaces.PostCreateState var breakNextState bool @@ -177,7 +175,7 @@ type getUserResult struct { } // getUser fetches the user record associated with the GITHUB_TOKEN -func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult { +func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult { ch := make(chan getUserResult, 1) go func() { user, err := apiClient.GetUser(ctx) @@ -192,7 +190,7 @@ type locationResult struct { } // getLocation fetches the closest Codespace datacenter region/location to the user. -func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult { +func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult { ch := make(chan locationResult, 1) go func() { location, err := apiClient.GetCodespaceRegionLocation(ctx) @@ -236,7 +234,7 @@ func getBranchName(branch string) (string, error) { } // getMachineName prompts the user to select the machine type, or validates the machine if non-empty. -func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { +func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { return "", fmt.Errorf("error requesting machine instance types: %w", err) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index b5d25e7bb..d7fed4e68 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -4,12 +4,10 @@ import ( "context" "errors" "fmt" - "os" "strings" "time" "github.com/AlecAivazis/survey/v2" - "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -24,7 +22,6 @@ type deleteOptions struct { isInteractive bool now func() time.Time - apiClient apiClient prompter prompter } @@ -33,20 +30,10 @@ type prompter interface { Confirm(message string) (bool, error) } -//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient -type apiClient interface { - GetUser(ctx context.Context) (*api.User, error) - GetCodespaceToken(ctx context.Context, user, name string) (string, error) - GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) - ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) - DeleteCodespace(ctx context.Context, user, name string) error -} - -func newDeleteCmd() *cobra.Command { +func newDeleteCmd(app *App) *cobra.Command { opts := deleteOptions{ isInteractive: hasTTY, now: time.Now, - apiClient: api.New(os.Getenv("GITHUB_TOKEN")), prompter: &surveyPrompter{}, } @@ -58,8 +45,7 @@ func newDeleteCmd() *cobra.Command { if opts.deleteAll && opts.repoFilter != "" { return errors.New("both --all and --repo is not supported") } - log := output.NewLogger(cmd.OutOrStdout(), cmd.ErrOrStderr(), !opts.isInteractive) - return delete(context.Background(), log, opts) + return app.Delete(cmd.Context(), opts) }, } @@ -72,12 +58,8 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -type logger interface { - Errorf(format string, v ...interface{}) (int, error) -} - -func delete(ctx context.Context, log logger, opts deleteOptions) error { - user, err := opts.apiClient.GetUser(ctx) +func (a *App) Delete(ctx context.Context, opts deleteOptions) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } @@ -85,7 +67,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { var codespaces []*api.Codespace nameFilter := opts.codespaceName if nameFilter == "" { - codespaces, err = opts.apiClient.ListCodespaces(ctx, user.Login) + codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } @@ -99,12 +81,12 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { } } else { // TODO: this token is discarded and then re-requested later in DeleteCodespace - token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) + token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) if err != nil { return fmt.Errorf("error getting codespace token: %w", err) } - codespace, err := opts.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) + codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) if err != nil { return fmt.Errorf("error fetching codespace information: %w", err) } @@ -150,8 +132,8 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { for _, c := range codespacesToDelete { codespaceName := c.Name g.Go(func() error { - if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { - _, _ = log.Errorf("error deleting codespace %q: %v\n", codespaceName, err) + if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { + _, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err) return err } return nil diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index 47e6a4d6c..ab7b01d30 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -186,7 +186,6 @@ func TestDelete(t *testing.T) { } } opts := tt.opts - opts.apiClient = apiMock opts.now = func() time.Time { return now } opts.prompter = &prompterMock{ ConfirmFunc: func(msg string) (bool, error) { @@ -200,8 +199,11 @@ func TestDelete(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - log := output.NewLogger(stdout, stderr, false) - err := delete(context.Background(), log, opts) + app := &App{ + apiClient: apiMock, + logger: output.NewLogger(stdout, stderr, false), + } + err := app.Delete(context.Background(), opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 065b7aa6d..842b9313d 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -14,7 +14,7 @@ type listOptions struct { asJSON bool } -func newListCmd() *cobra.Command { +func newListCmd(app *App) *cobra.Command { opts := &listOptions{} listCmd := &cobra.Command{ @@ -22,7 +22,7 @@ func newListCmd() *cobra.Command { Short: "List your codespaces", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return list(opts) + return app.List(cmd.Context(), opts) }, } @@ -31,16 +31,13 @@ func newListCmd() *cobra.Command { return listCmd } -func list(opts *listOptions) error { - apiClient := api.New(GithubToken) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +func (a *App) List(ctx context.Context, opts *listOptions) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespaces, err := apiClient.ListCodespaces(ctx, user.Login) + codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 0cddc6377..7f73d893c 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -4,29 +4,24 @@ import ( "context" "fmt" "net" - "os" - "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) -func newLogsCmd() *cobra.Command { +func newLogsCmd(app *App) *cobra.Command { var ( codespace string follow bool ) - log := output.NewLogger(os.Stdout, os.Stderr, false) - logsCmd := &cobra.Command{ Use: "logs", Short: "Access codespace logs", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return logs(context.Background(), log, codespace, follow) + return app.Logs(cmd.Context(), codespace, follow) }, } @@ -36,29 +31,27 @@ func newLogsCmd() *cobra.Command { return logsCmd } -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { +func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(GithubToken) - - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("getting user: %w", err) } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } @@ -76,7 +69,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow defer listen.Close() localPort := listen.Addr().(*net.TCPAddr).Port - log.Println("Fetching SSH Details...") + a.logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go index 6b890d740..7c6b2a175 100644 --- a/cmd/ghcs/main/main.go +++ b/cmd/ghcs/main/main.go @@ -4,22 +4,45 @@ import ( "errors" "fmt" "io" + "net/http" "os" "github.com/github/ghcs/cmd/ghcs" + "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" ) func main() { - rootCmd := ghcs.NewRootCmd() + token := os.Getenv("GITHUB_TOKEN") + rootCmd := ghcs.NewRootCmd(ghcs.NewApp( + output.NewLogger(os.Stdout, os.Stderr, false), + api.New(token, http.DefaultClient), + )) + + // Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still + // function without the token set. + oldPreRun := rootCmd.PersistentPreRunE + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + if token == "" { + return errTokenMissing + } + if oldPreRun != nil { + return oldPreRun(cmd, args) + } + return nil + } + if cmd, err := rootCmd.ExecuteC(); err != nil { explainError(os.Stderr, err, cmd) os.Exit(1) } } +var errTokenMissing = errors.New("GITHUB_TOKEN is missing") + func explainError(w io.Writer, err error, cmd *cobra.Command) { - if errors.Is(err, ghcs.ErrTokenMissing) { + if errors.Is(err, errTokenMissing) { fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go index 256a30ec3..93abe7ed6 100644 --- a/cmd/ghcs/mock_api.go +++ b/cmd/ghcs/mock_api.go @@ -16,21 +16,42 @@ import ( // // // make and configure a mocked apiClient // mockedapiClient := &apiClientMock{ +// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) { +// panic("mock out the AuthorizedKeys method") +// }, +// CreateCodespaceFunc: func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +// panic("mock out the CreateCodespace method") +// }, // DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { // panic("mock out the DeleteCodespace method") // }, // GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { // panic("mock out the GetCodespace method") // }, +// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { +// panic("mock out the GetCodespaceRegionLocation method") +// }, +// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { +// panic("mock out the GetCodespaceRepositoryContents method") +// }, // GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) { // panic("mock out the GetCodespaceToken method") // }, +// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { +// panic("mock out the GetCodespacesSKUs method") +// }, +// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { +// panic("mock out the GetRepository method") +// }, // GetUserFunc: func(ctx context.Context) (*api.User, error) { // panic("mock out the GetUser method") // }, // ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, +// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error { +// panic("mock out the StartCodespace method") +// }, // } // // // use mockedapiClient in code that requires apiClient @@ -38,23 +59,60 @@ import ( // // } type apiClientMock struct { + // AuthorizedKeysFunc mocks the AuthorizedKeys method. + AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error) + + // CreateCodespaceFunc mocks the CreateCodespace method. + CreateCodespaceFunc func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + // DeleteCodespaceFunc mocks the DeleteCodespace method. DeleteCodespaceFunc func(ctx context.Context, user string, name string) error // GetCodespaceFunc mocks the GetCodespace method. GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) + // GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method. + GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error) + + // GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) + // GetCodespaceTokenFunc mocks the GetCodespaceToken method. GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error) + // GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method. + GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) + + // GetRepositoryFunc mocks the GetRepository method. + GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) + // GetUserFunc mocks the GetUser method. GetUserFunc func(ctx context.Context) (*api.User, error) // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error) + // StartCodespaceFunc mocks the StartCodespace method. + StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error + // calls tracks calls to the methods. calls struct { + // AuthorizedKeys holds details about calls to the AuthorizedKeys method. + AuthorizedKeys []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + } + // CreateCodespace holds details about calls to the CreateCodespace method. + CreateCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Logger is the logger argument value. + Logger api.Logger + // Params is the params argument value. + Params *api.CreateCodespaceParams + } // DeleteCodespace holds details about calls to the DeleteCodespace method. DeleteCodespace []struct { // Ctx is the ctx argument value. @@ -75,6 +133,20 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method. + GetCodespaceRegionLocation []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContents []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Codespace is the codespace argument value. + Codespace *api.Codespace + // Path is the path argument value. + Path string + } // GetCodespaceToken holds details about calls to the GetCodespaceToken method. GetCodespaceToken []struct { // Ctx is the ctx argument value. @@ -84,6 +156,26 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method. + GetCodespacesSKUs []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User *api.User + // Repository is the repository argument value. + Repository *api.Repository + // Branch is the branch argument value. + Branch string + // Location is the location argument value. + Location string + } + // GetRepository holds details about calls to the GetRepository method. + GetRepository []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Nwo is the nwo argument value. + Nwo string + } // GetUser holds details about calls to the GetUser method. GetUser []struct { // Ctx is the ctx argument value. @@ -96,12 +188,102 @@ type apiClientMock struct { // User is the user argument value. User string } + // StartCodespace holds details about calls to the StartCodespace method. + StartCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + // Codespace is the codespace argument value. + Codespace *api.Codespace + } } - lockDeleteCodespace sync.RWMutex - lockGetCodespace sync.RWMutex - lockGetCodespaceToken sync.RWMutex - lockGetUser sync.RWMutex - lockListCodespaces sync.RWMutex + lockAuthorizedKeys sync.RWMutex + lockCreateCodespace sync.RWMutex + lockDeleteCodespace sync.RWMutex + lockGetCodespace sync.RWMutex + lockGetCodespaceRegionLocation sync.RWMutex + lockGetCodespaceRepositoryContents sync.RWMutex + lockGetCodespaceToken sync.RWMutex + lockGetCodespacesSKUs sync.RWMutex + lockGetRepository sync.RWMutex + lockGetUser sync.RWMutex + lockListCodespaces sync.RWMutex + lockStartCodespace sync.RWMutex +} + +// AuthorizedKeys calls AuthorizedKeysFunc. +func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + if mock.AuthorizedKeysFunc == nil { + panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called") + } + callInfo := struct { + Ctx context.Context + User string + }{ + Ctx: ctx, + User: user, + } + mock.lockAuthorizedKeys.Lock() + mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo) + mock.lockAuthorizedKeys.Unlock() + return mock.AuthorizedKeysFunc(ctx, user) +} + +// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys. +// Check the length with: +// len(mockedapiClient.AuthorizedKeysCalls()) +func (mock *apiClientMock) AuthorizedKeysCalls() []struct { + Ctx context.Context + User string +} { + var calls []struct { + Ctx context.Context + User string + } + mock.lockAuthorizedKeys.RLock() + calls = mock.calls.AuthorizedKeys + mock.lockAuthorizedKeys.RUnlock() + return calls +} + +// CreateCodespace calls CreateCodespaceFunc. +func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { + if mock.CreateCodespaceFunc == nil { + panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams + }{ + Ctx: ctx, + Logger: logger, + Params: params, + } + mock.lockCreateCodespace.Lock() + mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo) + mock.lockCreateCodespace.Unlock() + return mock.CreateCodespaceFunc(ctx, logger, params) +} + +// CreateCodespaceCalls gets all the calls that were made to CreateCodespace. +// Check the length with: +// len(mockedapiClient.CreateCodespaceCalls()) +func (mock *apiClientMock) CreateCodespaceCalls() []struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams +} { + var calls []struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams + } + mock.lockCreateCodespace.RLock() + calls = mock.calls.CreateCodespace + mock.lockCreateCodespace.RUnlock() + return calls } // DeleteCodespace calls DeleteCodespaceFunc. @@ -186,6 +368,76 @@ func (mock *apiClientMock) GetCodespaceCalls() []struct { return calls } +// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc. +func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) { + if mock.GetCodespaceRegionLocationFunc == nil { + panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetCodespaceRegionLocation.Lock() + mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo) + mock.lockGetCodespaceRegionLocation.Unlock() + return mock.GetCodespaceRegionLocationFunc(ctx) +} + +// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation. +// Check the length with: +// len(mockedapiClient.GetCodespaceRegionLocationCalls()) +func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetCodespaceRegionLocation.RLock() + calls = mock.calls.GetCodespaceRegionLocation + mock.lockGetCodespaceRegionLocation.RUnlock() + return calls +} + +// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc. +func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { + if mock.GetCodespaceRepositoryContentsFunc == nil { + panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called") + } + callInfo := struct { + Ctx context.Context + Codespace *api.Codespace + Path string + }{ + Ctx: ctx, + Codespace: codespace, + Path: path, + } + mock.lockGetCodespaceRepositoryContents.Lock() + mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo) + mock.lockGetCodespaceRepositoryContents.Unlock() + return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path) +} + +// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents. +// Check the length with: +// len(mockedapiClient.GetCodespaceRepositoryContentsCalls()) +func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct { + Ctx context.Context + Codespace *api.Codespace + Path string +} { + var calls []struct { + Ctx context.Context + Codespace *api.Codespace + Path string + } + mock.lockGetCodespaceRepositoryContents.RLock() + calls = mock.calls.GetCodespaceRepositoryContents + mock.lockGetCodespaceRepositoryContents.RUnlock() + return calls +} + // GetCodespaceToken calls GetCodespaceTokenFunc. func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) { if mock.GetCodespaceTokenFunc == nil { @@ -225,6 +477,88 @@ func (mock *apiClientMock) GetCodespaceTokenCalls() []struct { return calls } +// GetCodespacesSKUs calls GetCodespacesSKUsFunc. +func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { + if mock.GetCodespacesSKUsFunc == nil { + panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called") + } + callInfo := struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + }{ + Ctx: ctx, + User: user, + Repository: repository, + Branch: branch, + Location: location, + } + mock.lockGetCodespacesSKUs.Lock() + mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo) + mock.lockGetCodespacesSKUs.Unlock() + return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location) +} + +// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs. +// Check the length with: +// len(mockedapiClient.GetCodespacesSKUsCalls()) +func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string +} { + var calls []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + } + mock.lockGetCodespacesSKUs.RLock() + calls = mock.calls.GetCodespacesSKUs + mock.lockGetCodespacesSKUs.RUnlock() + return calls +} + +// GetRepository calls GetRepositoryFunc. +func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) { + if mock.GetRepositoryFunc == nil { + panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called") + } + callInfo := struct { + Ctx context.Context + Nwo string + }{ + Ctx: ctx, + Nwo: nwo, + } + mock.lockGetRepository.Lock() + mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo) + mock.lockGetRepository.Unlock() + return mock.GetRepositoryFunc(ctx, nwo) +} + +// GetRepositoryCalls gets all the calls that were made to GetRepository. +// Check the length with: +// len(mockedapiClient.GetRepositoryCalls()) +func (mock *apiClientMock) GetRepositoryCalls() []struct { + Ctx context.Context + Nwo string +} { + var calls []struct { + Ctx context.Context + Nwo string + } + mock.lockGetRepository.RLock() + calls = mock.calls.GetRepository + mock.lockGetRepository.RUnlock() + return calls +} + // GetUser calls GetUserFunc. func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { if mock.GetUserFunc == nil { @@ -290,3 +624,42 @@ func (mock *apiClientMock) ListCodespacesCalls() []struct { mock.lockListCodespaces.RUnlock() return calls } + +// StartCodespace calls StartCodespaceFunc. +func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error { + if mock.StartCodespaceFunc == nil { + panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Token string + Codespace *api.Codespace + }{ + Ctx: ctx, + Token: token, + Codespace: codespace, + } + mock.lockStartCodespace.Lock() + mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo) + mock.lockStartCodespace.Unlock() + return mock.StartCodespaceFunc(ctx, token, codespace) +} + +// StartCodespaceCalls gets all the calls that were made to StartCodespace. +// Check the length with: +// len(mockedapiClient.StartCodespaceCalls()) +func (mock *apiClientMock) StartCodespaceCalls() []struct { + Ctx context.Context + Token string + Codespace *api.Codespace +} { + var calls []struct { + Ctx context.Context + Token string + Codespace *api.Codespace + } + mock.lockStartCodespace.RLock() + calls = mock.calls.StartCodespace + mock.lockStartCodespace.RUnlock() + return calls +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f423245bd..06eabad6d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -22,7 +22,7 @@ import ( // newPortsCmd returns a Cobra "ports" command that displays a table of available ports, // according to the specified flags. -func newPortsCmd() *cobra.Command { +func newPortsCmd(app *App) *cobra.Command { var ( codespace string asJSON bool @@ -33,31 +33,28 @@ func newPortsCmd() *cobra.Command { Short: "List ports in a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return ports(codespace, asJSON) + return app.ListPorts(cmd.Context(), codespace, asJSON) }, } portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") - portsCmd.AddCommand(newPortsPublicCmd()) - portsCmd.AddCommand(newPortsPrivateCmd()) - portsCmd.AddCommand(newPortsForwardCmd()) + portsCmd.AddCommand(newPortsPublicCmd(app)) + portsCmd.AddCommand(newPortsPrivateCmd(app)) + portsCmd.AddCommand(newPortsForwardCmd(app)) return portsCmd } -func ports(codespaceName string, asJSON bool) (err error) { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, asJSON) - - user, err := apiClient.GetUser(ctx) +// ListPorts lists known ports in a codespace. +func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { // TODO(josebalius): remove special handling of this error here and it other places if err == errNoCodespaces { @@ -66,15 +63,15 @@ func ports(codespaceName string, asJSON bool) (err error) { return fmt.Errorf("error choosing codespace: %w", err) } - devContainerCh := getDevContainer(ctx, apiClient, codespace) + devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } defer safeClose(session, &err) - log.Println("Loading ports...") + a.logger.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) if err != nil { return fmt.Errorf("error getting ports of shared servers: %w", err) @@ -83,7 +80,7 @@ func ports(codespaceName string, asJSON bool) (err error) { devContainerResult := <-devContainerCh if devContainerResult.err != nil { // Warn about failure to read the devcontainer file. Not a ghcs command error. - _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) + _, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) } table := output.NewTable(os.Stdout, asJSON) @@ -122,7 +119,7 @@ type portAttribute struct { Label string `json:"label"` } -func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Codespace) <-chan devContainerResult { +func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult { ch := make(chan devContainerResult, 1) go func() { contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") @@ -154,7 +151,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod } // newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. -func newPortsPublicCmd() *cobra.Command { +func newPortsPublicCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "public ", Short: "Mark port as public", @@ -168,14 +165,13 @@ func newPortsPublicCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, codespace, args[0], true) + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true) }, } } // newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. -func newPortsPrivateCmd() *cobra.Command { +func newPortsPrivateCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "private ", Short: "Mark port as private", @@ -189,22 +185,18 @@ func newPortsPrivateCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, codespace, args[0], false) + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false) }, } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { - ctx := context.Background() - apiClient := api.New(GithubToken) - - user, err := apiClient.GetUser(ctx) +func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { if err == errNoCodespaces { return err @@ -212,7 +204,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -231,14 +223,14 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if !public { state = "PRIVATE" } - log.Printf("Port %s is now %s.\n", sourcePort, state) + a.logger.Printf("Port %s is now %s.\n", sourcePort, state) return nil } // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. -func newPortsForwardCmd() *cobra.Command { +func newPortsForwardCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "forward :...", Short: "Forward ports", @@ -252,27 +244,23 @@ func newPortsForwardCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return forwardPorts(log, codespace, args) + return app.ForwardPorts(cmd.Context(), codespace, args) }, } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { - ctx := context.Background() - apiClient := api.New(GithubToken) - +func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) { portPairs, err := getPortPairs(ports) if err != nil { return fmt.Errorf("get port pairs: %w", err) } - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { if err == errNoCodespaces { return err @@ -280,7 +268,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -297,7 +285,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err return err } defer listen.Close() - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) return fwd.ForwardToListener(ctx, listen) // error always non-nil diff --git a/cmd/ghcs/root.go b/cmd/ghcs/root.go index 6db4144a8..b71f4a0ff 100644 --- a/cmd/ghcs/root.go +++ b/cmd/ghcs/root.go @@ -1,10 +1,8 @@ package ghcs import ( - "errors" "fmt" "log" - "os" "strconv" "strings" @@ -15,10 +13,7 @@ import ( var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. -// GithubToken is a temporary stopgap to make the token configurable by apps that import this package -var GithubToken = os.Getenv("GITHUB_TOKEN") - -func NewRootCmd() *cobra.Command { +func NewRootCmd(app *App) *cobra.Command { var lightstep string root := &cobra.Command{ @@ -32,28 +27,23 @@ token to access the GitHub API with.`, Version: version, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if os.Getenv("GITHUB_TOKEN") == "" { - return ErrTokenMissing - } return initLightstep(lightstep) }, } root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") - root.AddCommand(newCodeCmd()) - root.AddCommand(newCreateCmd()) - root.AddCommand(newDeleteCmd()) - root.AddCommand(newListCmd()) - root.AddCommand(newLogsCmd()) - root.AddCommand(newPortsCmd()) - root.AddCommand(newSSHCmd()) + root.AddCommand(newCodeCmd(app)) + root.AddCommand(newCreateCmd(app)) + root.AddCommand(newDeleteCmd(app)) + root.AddCommand(newListCmd(app)) + root.AddCommand(newLogsCmd(app)) + root.AddCommand(newPortsCmd(app)) + root.AddCommand(newSSHCmd(app)) return root } -var ErrTokenMissing = errors.New("GITHUB_TOKEN is missing") - // initLightstep parses the --lightstep=service:token@host:port flag and // enables tracing if non-empty. func initLightstep(config string) error { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index bb771107a..bda7c28bb 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,16 +4,13 @@ import ( "context" "fmt" "net" - "os" - "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) -func newSSHCmd() *cobra.Command { +func newSSHCmd(app *App) *cobra.Command { var sshProfile, codespaceName string var sshServerPort int @@ -21,7 +18,7 @@ func newSSHCmd() *cobra.Command { Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) + return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort) }, } @@ -32,30 +29,28 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { +// SSH opens an ssh session or runs an ssh command in a codespace. +func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(GithubToken) - log := output.NewLogger(os.Stdout, os.Stderr, false) - - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -65,7 +60,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string return err } - log.Println("Fetching SSH Details...") + a.logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) @@ -86,7 +81,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - log.Println("Ready...") + a.logger.Println("Ready...") tunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) @@ -95,7 +90,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string shellClosed := make(chan error, 1) go func() { - shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) }() select { diff --git a/internal/api/api.go b/internal/api/api.go index bfccfc6c9..4cce56894 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -45,14 +45,18 @@ const githubAPI = "https://api.github.com" type API struct { token string - client *http.Client + client httpClient githubAPI string } -func New(token string) *API { +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +func New(token string, httpClient httpClient) *API { return &API{ token: token, - client: &http.Client{}, + client: httpClient, githubAPI: githubAPI, } } @@ -272,6 +276,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) return nil, fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { @@ -306,6 +311,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes return fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") if err != nil { @@ -417,14 +423,14 @@ type CreateCodespaceParams struct { Branch, Machine, Location string } -type logger interface { +type Logger interface { Print(v ...interface{}) (int, error) Println(v ...interface{}) (int, error) } // CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it // fails to create. -func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCodespaceParams) (*Codespace, error) { +func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) @@ -529,6 +535,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName st return fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { @@ -628,6 +635,8 @@ func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http } func (a *API) setHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.token) + if a.token != "" { + req.Header.Set("Authorization", "Bearer "+a.token) + } req.Header.Set("Accept", "application/vnd.github.v3+json") } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 1cd605abc..f3cf71b51 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,9 +23,15 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +type apiClient interface { + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + GetCodespaceToken(ctx context.Context, user, codespace string) (string, error) + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error +} + // ConnectToLiveshare waits for a Codespace to become running, // and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { +func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index c7d61b41e..0b395d6e3 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { +func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err) From dc8f6ef183f6c4d7a0f4135376d54724302abb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 17:30:31 +0200 Subject: [PATCH 2/3] No longer accept a logger in CreateCodespace The API layer shouldn't concern itself with logging progress to stderr. Instead, we will subsequently add progress indicators in the caller around CreateCodespace and other potentially slow commands as needed. --- cmd/ghcs/common.go | 2 +- cmd/ghcs/create.go | 2 +- cmd/ghcs/mock_api.go | 14 ++++---------- internal/api/api.go | 8 +------- 4 files changed, 7 insertions(+), 19 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e60fa7c96..fcdbb9f11 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -38,7 +38,7 @@ type apiClient interface { ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) DeleteCodespace(ctx context.Context, user, name string) error StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error - CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) GetCodespaceRegionLocation(ctx context.Context) (string, error) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index c92a6edff..7e861e08d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -80,7 +80,7 @@ func (a *App) Create(ctx context.Context, opts *createOptions) error { } a.logger.Print("Creating your codespace...") - codespace, err := a.apiClient.CreateCodespace(ctx, a.logger, &api.CreateCodespaceParams{ + codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{ User: userResult.User.Login, RepositoryID: repository.ID, Branch: branch, diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go index 93abe7ed6..ef08c0a78 100644 --- a/cmd/ghcs/mock_api.go +++ b/cmd/ghcs/mock_api.go @@ -19,7 +19,7 @@ import ( // AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) { // panic("mock out the AuthorizedKeys method") // }, -// CreateCodespaceFunc: func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { // panic("mock out the CreateCodespace method") // }, // DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { @@ -63,7 +63,7 @@ type apiClientMock struct { AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error) // CreateCodespaceFunc mocks the CreateCodespace method. - CreateCodespaceFunc func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) // DeleteCodespaceFunc mocks the DeleteCodespace method. DeleteCodespaceFunc func(ctx context.Context, user string, name string) error @@ -108,8 +108,6 @@ type apiClientMock struct { CreateCodespace []struct { // Ctx is the ctx argument value. Ctx context.Context - // Logger is the logger argument value. - Logger api.Logger // Params is the params argument value. Params *api.CreateCodespaceParams } @@ -248,23 +246,21 @@ func (mock *apiClientMock) AuthorizedKeysCalls() []struct { } // CreateCodespace calls CreateCodespaceFunc. -func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { if mock.CreateCodespaceFunc == nil { panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called") } callInfo := struct { Ctx context.Context - Logger api.Logger Params *api.CreateCodespaceParams }{ Ctx: ctx, - Logger: logger, Params: params, } mock.lockCreateCodespace.Lock() mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo) mock.lockCreateCodespace.Unlock() - return mock.CreateCodespaceFunc(ctx, logger, params) + return mock.CreateCodespaceFunc(ctx, params) } // CreateCodespaceCalls gets all the calls that were made to CreateCodespace. @@ -272,12 +268,10 @@ func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logge // len(mockedapiClient.CreateCodespaceCalls()) func (mock *apiClientMock) CreateCodespaceCalls() []struct { Ctx context.Context - Logger api.Logger Params *api.CreateCodespaceParams } { var calls []struct { Ctx context.Context - Logger api.Logger Params *api.CreateCodespaceParams } mock.lockCreateCodespace.RLock() diff --git a/internal/api/api.go b/internal/api/api.go index 4cce56894..efc24bcfb 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -423,14 +423,9 @@ type CreateCodespaceParams struct { Branch, Machine, Location string } -type Logger interface { - Print(v ...interface{}) (int, error) - Println(v ...interface{}) (int, error) -} - // CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it // fails to create. -func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCodespaceParams) (*Codespace, error) { +func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) @@ -452,7 +447,6 @@ func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCod case <-ctx.Done(): return nil, ctx.Err() case <-ticker.C: - log.Print(".") token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name) if err != nil { if err == ErrNotProvisioned { From c82d4c54724d9d879350052d7f0c993d92ec13c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 17:36:18 +0200 Subject: [PATCH 3/3] Avoid passing params struct as pointer --- cmd/ghcs/create.go | 4 ++-- cmd/ghcs/list.go | 14 +++++--------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 7e861e08d..7174e7721 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -23,7 +23,7 @@ type createOptions struct { } func newCreateCmd(app *App) *cobra.Command { - opts := &createOptions{} + opts := createOptions{} createCmd := &cobra.Command{ Use: "create", @@ -43,7 +43,7 @@ func newCreateCmd(app *App) *cobra.Command { } // Create creates a new Codespace -func (a *App) Create(ctx context.Context, opts *createOptions) error { +func (a *App) Create(ctx context.Context, opts createOptions) error { locationCh := getLocation(ctx, a.apiClient) userCh := getUser(ctx, a.apiClient) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 842b9313d..1fc59cff0 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -10,28 +10,24 @@ import ( "github.com/spf13/cobra" ) -type listOptions struct { - asJSON bool -} - func newListCmd(app *App) *cobra.Command { - opts := &listOptions{} + var asJSON bool listCmd := &cobra.Command{ Use: "list", Short: "List your codespaces", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.List(cmd.Context(), opts) + return app.List(cmd.Context(), asJSON) }, } - listCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON") + listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") return listCmd } -func (a *App) List(ctx context.Context, opts *listOptions) error { +func (a *App) List(ctx context.Context, asJSON bool) error { user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) @@ -42,7 +38,7 @@ func (a *App) List(ctx context.Context, opts *listOptions) error { return fmt.Errorf("error getting codespaces: %w", err) } - table := output.NewTable(os.Stdout, opts.asJSON) + table := output.NewTable(os.Stdout, asJSON) table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) for _, codespace := range codespaces { table.Append([]string{