diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 08d2cff1a..cfcd989e2 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -5,12 +5,11 @@ import ( "fmt" "net/url" - "github.com/github/ghcs/internal/api" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) -func newCodeCmd() *cobra.Command { +func newCodeCmd(app *App) *cobra.Command { var ( codespace string useInsiders bool @@ -21,7 +20,7 @@ func newCodeCmd() *cobra.Command { Short: "Open a codespace in VS Code", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return code(codespace, useInsiders) + return app.VSCode(cmd.Context(), codespace, useInsiders) }, } @@ -31,17 +30,15 @@ func newCodeCmd() *cobra.Command { return codeCmd } -func code(codespaceName string, useInsiders bool) error { - apiClient := api.New(GithubToken) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +// VSCode opens a codespace in the local VS VSCode application. +func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } if codespaceName == "" { - codespace, err := chooseCodespace(ctx, apiClient, user) + codespace, err := chooseCodespace(ctx, a.apiClient, user) if err != nil { if err == errNoCodespaces { return err diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 371ca30b8..e60fa7c96 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -12,14 +12,43 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/term" ) +type App struct { + apiClient apiClient + logger *output.Logger +} + +func NewApp(logger *output.Logger, apiClient apiClient) *App { + return &App{ + apiClient: apiClient, + logger: logger, + } +} + +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient +type apiClient interface { + GetUser(ctx context.Context) (*api.User, error) + GetCodespaceToken(ctx context.Context, user, name string) (string, error) + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) + DeleteCodespace(ctx context.Context, user, name string) error + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error + CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + GetRepository(ctx context.Context, nwo string) (*api.Repository, error) + AuthorizedKeys(ctx context.Context, user string) ([]byte, error) + GetCodespaceRegionLocation(ctx context.Context) (string, error) + GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error) + GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) +} + var errNoCodespaces = errors.New("you have no codespaces") -func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { +func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user.Login) if err != nil { return nil, fmt.Errorf("error getting codespaces: %w", err) @@ -68,7 +97,7 @@ func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) ( // getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. // It then fetches the codespace token and the codespace record. -func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { +func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = chooseCodespace(ctx, apiClient, user) if err != nil { @@ -135,7 +164,7 @@ func ask(qs []*survey.Question, response interface{}) error { // checkAuthorizedKeys reports an error if the user has not registered any SSH keys; // see https://github.com/github/ghcs/issues/166#issuecomment-921769703. // The check is not required for security but it improves the error message. -func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error { +func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error { keys, err := client.AuthorizedKeys(ctx, user) if err != nil { return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 345489f6b..c92a6edff 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -22,7 +22,7 @@ type createOptions struct { showStatus bool } -func newCreateCmd() *cobra.Command { +func newCreateCmd(app *App) *cobra.Command { opts := &createOptions{} createCmd := &cobra.Command{ @@ -30,7 +30,7 @@ func newCreateCmd() *cobra.Command { Short: "Create a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return create(opts) + return app.Create(cmd.Context(), opts) }, } @@ -42,12 +42,10 @@ func newCreateCmd() *cobra.Command { return createCmd } -func create(opts *createOptions) error { - ctx := context.Background() - apiClient := api.New(GithubToken) - locationCh := getLocation(ctx, apiClient) - userCh := getUser(ctx, apiClient) - log := output.NewLogger(os.Stdout, os.Stderr, false) +// Create creates a new Codespace +func (a *App) Create(ctx context.Context, opts *createOptions) error { + locationCh := getLocation(ctx, a.apiClient) + userCh := getUser(ctx, a.apiClient) repo, err := getRepoName(opts.repo) if err != nil { @@ -58,7 +56,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting branch name: %w", err) } - repository, err := apiClient.GetRepository(ctx, repo) + repository, err := a.apiClient.GetRepository(ctx, repo) if err != nil { return fmt.Errorf("error getting repository: %w", err) } @@ -73,7 +71,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting codespace user: %w", userResult.Err) } - machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient) + machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient) if err != nil { return fmt.Errorf("error getting machine type: %w", err) } @@ -81,26 +79,26 @@ func create(opts *createOptions) error { return errors.New("there are no available machine types for this repository") } - log.Print("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, log, &api.CreateCodespaceParams{ + a.logger.Print("Creating your codespace...") + codespace, err := a.apiClient.CreateCodespace(ctx, a.logger, &api.CreateCodespaceParams{ User: userResult.User.Login, RepositoryID: repository.ID, Branch: branch, Machine: machine, Location: locationResult.Location, }) - log.Print("\n") + a.logger.Print("\n") if err != nil { return fmt.Errorf("error creating codespace: %w", err) } if opts.showStatus { - if err := showStatus(ctx, log, apiClient, userResult.User, codespace); err != nil { + if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil { return fmt.Errorf("show status: %w", err) } } - log.Printf("Codespace created: ") + a.logger.Printf("Codespace created: ") fmt.Fprintln(os.Stdout, codespace.Name) @@ -110,7 +108,7 @@ func create(opts *createOptions) error { // showStatus polls the codespace for a list of post create states and their status. It will keep polling // until all states have finished. Once all states have finished, we poll once more to check if any new // states have been introduced and stop polling otherwise. -func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { +func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error { var lastState codespaces.PostCreateState var breakNextState bool @@ -177,7 +175,7 @@ type getUserResult struct { } // getUser fetches the user record associated with the GITHUB_TOKEN -func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult { +func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult { ch := make(chan getUserResult, 1) go func() { user, err := apiClient.GetUser(ctx) @@ -192,7 +190,7 @@ type locationResult struct { } // getLocation fetches the closest Codespace datacenter region/location to the user. -func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult { +func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult { ch := make(chan locationResult, 1) go func() { location, err := apiClient.GetCodespaceRegionLocation(ctx) @@ -236,7 +234,7 @@ func getBranchName(branch string) (string, error) { } // getMachineName prompts the user to select the machine type, or validates the machine if non-empty. -func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { +func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { return "", fmt.Errorf("error requesting machine instance types: %w", err) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index b5d25e7bb..d7fed4e68 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -4,12 +4,10 @@ import ( "context" "errors" "fmt" - "os" "strings" "time" "github.com/AlecAivazis/survey/v2" - "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -24,7 +22,6 @@ type deleteOptions struct { isInteractive bool now func() time.Time - apiClient apiClient prompter prompter } @@ -33,20 +30,10 @@ type prompter interface { Confirm(message string) (bool, error) } -//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient -type apiClient interface { - GetUser(ctx context.Context) (*api.User, error) - GetCodespaceToken(ctx context.Context, user, name string) (string, error) - GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) - ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) - DeleteCodespace(ctx context.Context, user, name string) error -} - -func newDeleteCmd() *cobra.Command { +func newDeleteCmd(app *App) *cobra.Command { opts := deleteOptions{ isInteractive: hasTTY, now: time.Now, - apiClient: api.New(os.Getenv("GITHUB_TOKEN")), prompter: &surveyPrompter{}, } @@ -58,8 +45,7 @@ func newDeleteCmd() *cobra.Command { if opts.deleteAll && opts.repoFilter != "" { return errors.New("both --all and --repo is not supported") } - log := output.NewLogger(cmd.OutOrStdout(), cmd.ErrOrStderr(), !opts.isInteractive) - return delete(context.Background(), log, opts) + return app.Delete(cmd.Context(), opts) }, } @@ -72,12 +58,8 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -type logger interface { - Errorf(format string, v ...interface{}) (int, error) -} - -func delete(ctx context.Context, log logger, opts deleteOptions) error { - user, err := opts.apiClient.GetUser(ctx) +func (a *App) Delete(ctx context.Context, opts deleteOptions) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } @@ -85,7 +67,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { var codespaces []*api.Codespace nameFilter := opts.codespaceName if nameFilter == "" { - codespaces, err = opts.apiClient.ListCodespaces(ctx, user.Login) + codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } @@ -99,12 +81,12 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { } } else { // TODO: this token is discarded and then re-requested later in DeleteCodespace - token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) + token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) if err != nil { return fmt.Errorf("error getting codespace token: %w", err) } - codespace, err := opts.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) + codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) if err != nil { return fmt.Errorf("error fetching codespace information: %w", err) } @@ -150,8 +132,8 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { for _, c := range codespacesToDelete { codespaceName := c.Name g.Go(func() error { - if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { - _, _ = log.Errorf("error deleting codespace %q: %v\n", codespaceName, err) + if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { + _, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err) return err } return nil diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index 47e6a4d6c..ab7b01d30 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -186,7 +186,6 @@ func TestDelete(t *testing.T) { } } opts := tt.opts - opts.apiClient = apiMock opts.now = func() time.Time { return now } opts.prompter = &prompterMock{ ConfirmFunc: func(msg string) (bool, error) { @@ -200,8 +199,11 @@ func TestDelete(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - log := output.NewLogger(stdout, stderr, false) - err := delete(context.Background(), log, opts) + app := &App{ + apiClient: apiMock, + logger: output.NewLogger(stdout, stderr, false), + } + err := app.Delete(context.Background(), opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 065b7aa6d..842b9313d 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -14,7 +14,7 @@ type listOptions struct { asJSON bool } -func newListCmd() *cobra.Command { +func newListCmd(app *App) *cobra.Command { opts := &listOptions{} listCmd := &cobra.Command{ @@ -22,7 +22,7 @@ func newListCmd() *cobra.Command { Short: "List your codespaces", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return list(opts) + return app.List(cmd.Context(), opts) }, } @@ -31,16 +31,13 @@ func newListCmd() *cobra.Command { return listCmd } -func list(opts *listOptions) error { - apiClient := api.New(GithubToken) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +func (a *App) List(ctx context.Context, opts *listOptions) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespaces, err := apiClient.ListCodespaces(ctx, user.Login) + codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 0cddc6377..7f73d893c 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -4,29 +4,24 @@ import ( "context" "fmt" "net" - "os" - "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) -func newLogsCmd() *cobra.Command { +func newLogsCmd(app *App) *cobra.Command { var ( codespace string follow bool ) - log := output.NewLogger(os.Stdout, os.Stderr, false) - logsCmd := &cobra.Command{ Use: "logs", Short: "Access codespace logs", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return logs(context.Background(), log, codespace, follow) + return app.Logs(cmd.Context(), codespace, follow) }, } @@ -36,29 +31,27 @@ func newLogsCmd() *cobra.Command { return logsCmd } -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { +func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(GithubToken) - - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("getting user: %w", err) } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } @@ -76,7 +69,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow defer listen.Close() localPort := listen.Addr().(*net.TCPAddr).Port - log.Println("Fetching SSH Details...") + a.logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go index 6b890d740..7c6b2a175 100644 --- a/cmd/ghcs/main/main.go +++ b/cmd/ghcs/main/main.go @@ -4,22 +4,45 @@ import ( "errors" "fmt" "io" + "net/http" "os" "github.com/github/ghcs/cmd/ghcs" + "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" ) func main() { - rootCmd := ghcs.NewRootCmd() + token := os.Getenv("GITHUB_TOKEN") + rootCmd := ghcs.NewRootCmd(ghcs.NewApp( + output.NewLogger(os.Stdout, os.Stderr, false), + api.New(token, http.DefaultClient), + )) + + // Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still + // function without the token set. + oldPreRun := rootCmd.PersistentPreRunE + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + if token == "" { + return errTokenMissing + } + if oldPreRun != nil { + return oldPreRun(cmd, args) + } + return nil + } + if cmd, err := rootCmd.ExecuteC(); err != nil { explainError(os.Stderr, err, cmd) os.Exit(1) } } +var errTokenMissing = errors.New("GITHUB_TOKEN is missing") + func explainError(w io.Writer, err error, cmd *cobra.Command) { - if errors.Is(err, ghcs.ErrTokenMissing) { + if errors.Is(err, errTokenMissing) { fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go index 256a30ec3..93abe7ed6 100644 --- a/cmd/ghcs/mock_api.go +++ b/cmd/ghcs/mock_api.go @@ -16,21 +16,42 @@ import ( // // // make and configure a mocked apiClient // mockedapiClient := &apiClientMock{ +// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) { +// panic("mock out the AuthorizedKeys method") +// }, +// CreateCodespaceFunc: func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +// panic("mock out the CreateCodespace method") +// }, // DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { // panic("mock out the DeleteCodespace method") // }, // GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { // panic("mock out the GetCodespace method") // }, +// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { +// panic("mock out the GetCodespaceRegionLocation method") +// }, +// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { +// panic("mock out the GetCodespaceRepositoryContents method") +// }, // GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) { // panic("mock out the GetCodespaceToken method") // }, +// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { +// panic("mock out the GetCodespacesSKUs method") +// }, +// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { +// panic("mock out the GetRepository method") +// }, // GetUserFunc: func(ctx context.Context) (*api.User, error) { // panic("mock out the GetUser method") // }, // ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, +// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error { +// panic("mock out the StartCodespace method") +// }, // } // // // use mockedapiClient in code that requires apiClient @@ -38,23 +59,60 @@ import ( // // } type apiClientMock struct { + // AuthorizedKeysFunc mocks the AuthorizedKeys method. + AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error) + + // CreateCodespaceFunc mocks the CreateCodespace method. + CreateCodespaceFunc func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + // DeleteCodespaceFunc mocks the DeleteCodespace method. DeleteCodespaceFunc func(ctx context.Context, user string, name string) error // GetCodespaceFunc mocks the GetCodespace method. GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) + // GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method. + GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error) + + // GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) + // GetCodespaceTokenFunc mocks the GetCodespaceToken method. GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error) + // GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method. + GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) + + // GetRepositoryFunc mocks the GetRepository method. + GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) + // GetUserFunc mocks the GetUser method. GetUserFunc func(ctx context.Context) (*api.User, error) // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error) + // StartCodespaceFunc mocks the StartCodespace method. + StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error + // calls tracks calls to the methods. calls struct { + // AuthorizedKeys holds details about calls to the AuthorizedKeys method. + AuthorizedKeys []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + } + // CreateCodespace holds details about calls to the CreateCodespace method. + CreateCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Logger is the logger argument value. + Logger api.Logger + // Params is the params argument value. + Params *api.CreateCodespaceParams + } // DeleteCodespace holds details about calls to the DeleteCodespace method. DeleteCodespace []struct { // Ctx is the ctx argument value. @@ -75,6 +133,20 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method. + GetCodespaceRegionLocation []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContents []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Codespace is the codespace argument value. + Codespace *api.Codespace + // Path is the path argument value. + Path string + } // GetCodespaceToken holds details about calls to the GetCodespaceToken method. GetCodespaceToken []struct { // Ctx is the ctx argument value. @@ -84,6 +156,26 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method. + GetCodespacesSKUs []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User *api.User + // Repository is the repository argument value. + Repository *api.Repository + // Branch is the branch argument value. + Branch string + // Location is the location argument value. + Location string + } + // GetRepository holds details about calls to the GetRepository method. + GetRepository []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Nwo is the nwo argument value. + Nwo string + } // GetUser holds details about calls to the GetUser method. GetUser []struct { // Ctx is the ctx argument value. @@ -96,12 +188,102 @@ type apiClientMock struct { // User is the user argument value. User string } + // StartCodespace holds details about calls to the StartCodespace method. + StartCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + // Codespace is the codespace argument value. + Codespace *api.Codespace + } } - lockDeleteCodespace sync.RWMutex - lockGetCodespace sync.RWMutex - lockGetCodespaceToken sync.RWMutex - lockGetUser sync.RWMutex - lockListCodespaces sync.RWMutex + lockAuthorizedKeys sync.RWMutex + lockCreateCodespace sync.RWMutex + lockDeleteCodespace sync.RWMutex + lockGetCodespace sync.RWMutex + lockGetCodespaceRegionLocation sync.RWMutex + lockGetCodespaceRepositoryContents sync.RWMutex + lockGetCodespaceToken sync.RWMutex + lockGetCodespacesSKUs sync.RWMutex + lockGetRepository sync.RWMutex + lockGetUser sync.RWMutex + lockListCodespaces sync.RWMutex + lockStartCodespace sync.RWMutex +} + +// AuthorizedKeys calls AuthorizedKeysFunc. +func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + if mock.AuthorizedKeysFunc == nil { + panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called") + } + callInfo := struct { + Ctx context.Context + User string + }{ + Ctx: ctx, + User: user, + } + mock.lockAuthorizedKeys.Lock() + mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo) + mock.lockAuthorizedKeys.Unlock() + return mock.AuthorizedKeysFunc(ctx, user) +} + +// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys. +// Check the length with: +// len(mockedapiClient.AuthorizedKeysCalls()) +func (mock *apiClientMock) AuthorizedKeysCalls() []struct { + Ctx context.Context + User string +} { + var calls []struct { + Ctx context.Context + User string + } + mock.lockAuthorizedKeys.RLock() + calls = mock.calls.AuthorizedKeys + mock.lockAuthorizedKeys.RUnlock() + return calls +} + +// CreateCodespace calls CreateCodespaceFunc. +func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { + if mock.CreateCodespaceFunc == nil { + panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams + }{ + Ctx: ctx, + Logger: logger, + Params: params, + } + mock.lockCreateCodespace.Lock() + mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo) + mock.lockCreateCodespace.Unlock() + return mock.CreateCodespaceFunc(ctx, logger, params) +} + +// CreateCodespaceCalls gets all the calls that were made to CreateCodespace. +// Check the length with: +// len(mockedapiClient.CreateCodespaceCalls()) +func (mock *apiClientMock) CreateCodespaceCalls() []struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams +} { + var calls []struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams + } + mock.lockCreateCodespace.RLock() + calls = mock.calls.CreateCodespace + mock.lockCreateCodespace.RUnlock() + return calls } // DeleteCodespace calls DeleteCodespaceFunc. @@ -186,6 +368,76 @@ func (mock *apiClientMock) GetCodespaceCalls() []struct { return calls } +// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc. +func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) { + if mock.GetCodespaceRegionLocationFunc == nil { + panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetCodespaceRegionLocation.Lock() + mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo) + mock.lockGetCodespaceRegionLocation.Unlock() + return mock.GetCodespaceRegionLocationFunc(ctx) +} + +// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation. +// Check the length with: +// len(mockedapiClient.GetCodespaceRegionLocationCalls()) +func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetCodespaceRegionLocation.RLock() + calls = mock.calls.GetCodespaceRegionLocation + mock.lockGetCodespaceRegionLocation.RUnlock() + return calls +} + +// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc. +func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { + if mock.GetCodespaceRepositoryContentsFunc == nil { + panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called") + } + callInfo := struct { + Ctx context.Context + Codespace *api.Codespace + Path string + }{ + Ctx: ctx, + Codespace: codespace, + Path: path, + } + mock.lockGetCodespaceRepositoryContents.Lock() + mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo) + mock.lockGetCodespaceRepositoryContents.Unlock() + return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path) +} + +// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents. +// Check the length with: +// len(mockedapiClient.GetCodespaceRepositoryContentsCalls()) +func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct { + Ctx context.Context + Codespace *api.Codespace + Path string +} { + var calls []struct { + Ctx context.Context + Codespace *api.Codespace + Path string + } + mock.lockGetCodespaceRepositoryContents.RLock() + calls = mock.calls.GetCodespaceRepositoryContents + mock.lockGetCodespaceRepositoryContents.RUnlock() + return calls +} + // GetCodespaceToken calls GetCodespaceTokenFunc. func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) { if mock.GetCodespaceTokenFunc == nil { @@ -225,6 +477,88 @@ func (mock *apiClientMock) GetCodespaceTokenCalls() []struct { return calls } +// GetCodespacesSKUs calls GetCodespacesSKUsFunc. +func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { + if mock.GetCodespacesSKUsFunc == nil { + panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called") + } + callInfo := struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + }{ + Ctx: ctx, + User: user, + Repository: repository, + Branch: branch, + Location: location, + } + mock.lockGetCodespacesSKUs.Lock() + mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo) + mock.lockGetCodespacesSKUs.Unlock() + return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location) +} + +// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs. +// Check the length with: +// len(mockedapiClient.GetCodespacesSKUsCalls()) +func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string +} { + var calls []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + } + mock.lockGetCodespacesSKUs.RLock() + calls = mock.calls.GetCodespacesSKUs + mock.lockGetCodespacesSKUs.RUnlock() + return calls +} + +// GetRepository calls GetRepositoryFunc. +func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) { + if mock.GetRepositoryFunc == nil { + panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called") + } + callInfo := struct { + Ctx context.Context + Nwo string + }{ + Ctx: ctx, + Nwo: nwo, + } + mock.lockGetRepository.Lock() + mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo) + mock.lockGetRepository.Unlock() + return mock.GetRepositoryFunc(ctx, nwo) +} + +// GetRepositoryCalls gets all the calls that were made to GetRepository. +// Check the length with: +// len(mockedapiClient.GetRepositoryCalls()) +func (mock *apiClientMock) GetRepositoryCalls() []struct { + Ctx context.Context + Nwo string +} { + var calls []struct { + Ctx context.Context + Nwo string + } + mock.lockGetRepository.RLock() + calls = mock.calls.GetRepository + mock.lockGetRepository.RUnlock() + return calls +} + // GetUser calls GetUserFunc. func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { if mock.GetUserFunc == nil { @@ -290,3 +624,42 @@ func (mock *apiClientMock) ListCodespacesCalls() []struct { mock.lockListCodespaces.RUnlock() return calls } + +// StartCodespace calls StartCodespaceFunc. +func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error { + if mock.StartCodespaceFunc == nil { + panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Token string + Codespace *api.Codespace + }{ + Ctx: ctx, + Token: token, + Codespace: codespace, + } + mock.lockStartCodespace.Lock() + mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo) + mock.lockStartCodespace.Unlock() + return mock.StartCodespaceFunc(ctx, token, codespace) +} + +// StartCodespaceCalls gets all the calls that were made to StartCodespace. +// Check the length with: +// len(mockedapiClient.StartCodespaceCalls()) +func (mock *apiClientMock) StartCodespaceCalls() []struct { + Ctx context.Context + Token string + Codespace *api.Codespace +} { + var calls []struct { + Ctx context.Context + Token string + Codespace *api.Codespace + } + mock.lockStartCodespace.RLock() + calls = mock.calls.StartCodespace + mock.lockStartCodespace.RUnlock() + return calls +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f423245bd..06eabad6d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -22,7 +22,7 @@ import ( // newPortsCmd returns a Cobra "ports" command that displays a table of available ports, // according to the specified flags. -func newPortsCmd() *cobra.Command { +func newPortsCmd(app *App) *cobra.Command { var ( codespace string asJSON bool @@ -33,31 +33,28 @@ func newPortsCmd() *cobra.Command { Short: "List ports in a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return ports(codespace, asJSON) + return app.ListPorts(cmd.Context(), codespace, asJSON) }, } portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") - portsCmd.AddCommand(newPortsPublicCmd()) - portsCmd.AddCommand(newPortsPrivateCmd()) - portsCmd.AddCommand(newPortsForwardCmd()) + portsCmd.AddCommand(newPortsPublicCmd(app)) + portsCmd.AddCommand(newPortsPrivateCmd(app)) + portsCmd.AddCommand(newPortsForwardCmd(app)) return portsCmd } -func ports(codespaceName string, asJSON bool) (err error) { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, asJSON) - - user, err := apiClient.GetUser(ctx) +// ListPorts lists known ports in a codespace. +func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { // TODO(josebalius): remove special handling of this error here and it other places if err == errNoCodespaces { @@ -66,15 +63,15 @@ func ports(codespaceName string, asJSON bool) (err error) { return fmt.Errorf("error choosing codespace: %w", err) } - devContainerCh := getDevContainer(ctx, apiClient, codespace) + devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } defer safeClose(session, &err) - log.Println("Loading ports...") + a.logger.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) if err != nil { return fmt.Errorf("error getting ports of shared servers: %w", err) @@ -83,7 +80,7 @@ func ports(codespaceName string, asJSON bool) (err error) { devContainerResult := <-devContainerCh if devContainerResult.err != nil { // Warn about failure to read the devcontainer file. Not a ghcs command error. - _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) + _, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) } table := output.NewTable(os.Stdout, asJSON) @@ -122,7 +119,7 @@ type portAttribute struct { Label string `json:"label"` } -func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Codespace) <-chan devContainerResult { +func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult { ch := make(chan devContainerResult, 1) go func() { contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") @@ -154,7 +151,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod } // newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. -func newPortsPublicCmd() *cobra.Command { +func newPortsPublicCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "public ", Short: "Mark port as public", @@ -168,14 +165,13 @@ func newPortsPublicCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, codespace, args[0], true) + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true) }, } } // newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. -func newPortsPrivateCmd() *cobra.Command { +func newPortsPrivateCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "private ", Short: "Mark port as private", @@ -189,22 +185,18 @@ func newPortsPrivateCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, codespace, args[0], false) + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false) }, } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { - ctx := context.Background() - apiClient := api.New(GithubToken) - - user, err := apiClient.GetUser(ctx) +func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { if err == errNoCodespaces { return err @@ -212,7 +204,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -231,14 +223,14 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if !public { state = "PRIVATE" } - log.Printf("Port %s is now %s.\n", sourcePort, state) + a.logger.Printf("Port %s is now %s.\n", sourcePort, state) return nil } // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. -func newPortsForwardCmd() *cobra.Command { +func newPortsForwardCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "forward :...", Short: "Forward ports", @@ -252,27 +244,23 @@ func newPortsForwardCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return forwardPorts(log, codespace, args) + return app.ForwardPorts(cmd.Context(), codespace, args) }, } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { - ctx := context.Background() - apiClient := api.New(GithubToken) - +func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) { portPairs, err := getPortPairs(ports) if err != nil { return fmt.Errorf("get port pairs: %w", err) } - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { if err == errNoCodespaces { return err @@ -280,7 +268,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -297,7 +285,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err return err } defer listen.Close() - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) return fwd.ForwardToListener(ctx, listen) // error always non-nil diff --git a/cmd/ghcs/root.go b/cmd/ghcs/root.go index 6db4144a8..b71f4a0ff 100644 --- a/cmd/ghcs/root.go +++ b/cmd/ghcs/root.go @@ -1,10 +1,8 @@ package ghcs import ( - "errors" "fmt" "log" - "os" "strconv" "strings" @@ -15,10 +13,7 @@ import ( var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. -// GithubToken is a temporary stopgap to make the token configurable by apps that import this package -var GithubToken = os.Getenv("GITHUB_TOKEN") - -func NewRootCmd() *cobra.Command { +func NewRootCmd(app *App) *cobra.Command { var lightstep string root := &cobra.Command{ @@ -32,28 +27,23 @@ token to access the GitHub API with.`, Version: version, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if os.Getenv("GITHUB_TOKEN") == "" { - return ErrTokenMissing - } return initLightstep(lightstep) }, } root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") - root.AddCommand(newCodeCmd()) - root.AddCommand(newCreateCmd()) - root.AddCommand(newDeleteCmd()) - root.AddCommand(newListCmd()) - root.AddCommand(newLogsCmd()) - root.AddCommand(newPortsCmd()) - root.AddCommand(newSSHCmd()) + root.AddCommand(newCodeCmd(app)) + root.AddCommand(newCreateCmd(app)) + root.AddCommand(newDeleteCmd(app)) + root.AddCommand(newListCmd(app)) + root.AddCommand(newLogsCmd(app)) + root.AddCommand(newPortsCmd(app)) + root.AddCommand(newSSHCmd(app)) return root } -var ErrTokenMissing = errors.New("GITHUB_TOKEN is missing") - // initLightstep parses the --lightstep=service:token@host:port flag and // enables tracing if non-empty. func initLightstep(config string) error { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index bb771107a..bda7c28bb 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,16 +4,13 @@ import ( "context" "fmt" "net" - "os" - "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) -func newSSHCmd() *cobra.Command { +func newSSHCmd(app *App) *cobra.Command { var sshProfile, codespaceName string var sshServerPort int @@ -21,7 +18,7 @@ func newSSHCmd() *cobra.Command { Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) + return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort) }, } @@ -32,30 +29,28 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { +// SSH opens an ssh session or runs an ssh command in a codespace. +func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(GithubToken) - log := output.NewLogger(os.Stdout, os.Stderr, false) - - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -65,7 +60,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string return err } - log.Println("Fetching SSH Details...") + a.logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) @@ -86,7 +81,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - log.Println("Ready...") + a.logger.Println("Ready...") tunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) @@ -95,7 +90,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string shellClosed := make(chan error, 1) go func() { - shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) }() select { diff --git a/internal/api/api.go b/internal/api/api.go index bfccfc6c9..4cce56894 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -45,14 +45,18 @@ const githubAPI = "https://api.github.com" type API struct { token string - client *http.Client + client httpClient githubAPI string } -func New(token string) *API { +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +func New(token string, httpClient httpClient) *API { return &API{ token: token, - client: &http.Client{}, + client: httpClient, githubAPI: githubAPI, } } @@ -272,6 +276,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) return nil, fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { @@ -306,6 +311,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes return fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") if err != nil { @@ -417,14 +423,14 @@ type CreateCodespaceParams struct { Branch, Machine, Location string } -type logger interface { +type Logger interface { Print(v ...interface{}) (int, error) Println(v ...interface{}) (int, error) } // CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it // fails to create. -func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCodespaceParams) (*Codespace, error) { +func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) @@ -529,6 +535,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName st return fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { @@ -628,6 +635,8 @@ func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http } func (a *API) setHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.token) + if a.token != "" { + req.Header.Set("Authorization", "Bearer "+a.token) + } req.Header.Set("Accept", "application/vnd.github.v3+json") } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 1cd605abc..f3cf71b51 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,9 +23,15 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +type apiClient interface { + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + GetCodespaceToken(ctx context.Context, user, codespace string) (string, error) + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error +} + // ConnectToLiveshare waits for a Codespace to become running, // and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { +func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index c7d61b41e..0b395d6e3 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { +func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err)