Merge pull request #190 from github/app-struct

Introduce an App struct that executes core business logic
This commit is contained in:
Mislav Marohnić 2021-09-28 15:46:57 +02:00 committed by GitHub
commit 0483765da5
15 changed files with 554 additions and 188 deletions

View file

@ -5,12 +5,11 @@ import (
"fmt"
"net/url"
"github.com/github/ghcs/internal/api"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
)
func newCodeCmd() *cobra.Command {
func newCodeCmd(app *App) *cobra.Command {
var (
codespace string
useInsiders bool
@ -21,7 +20,7 @@ func newCodeCmd() *cobra.Command {
Short: "Open a codespace in VS Code",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return code(codespace, useInsiders)
return app.VSCode(cmd.Context(), codespace, useInsiders)
},
}
@ -31,17 +30,15 @@ func newCodeCmd() *cobra.Command {
return codeCmd
}
func code(codespaceName string, useInsiders bool) error {
apiClient := api.New(GithubToken)
ctx := context.Background()
user, err := apiClient.GetUser(ctx)
// VSCode opens a codespace in the local VS VSCode application.
func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
if codespaceName == "" {
codespace, err := chooseCodespace(ctx, apiClient, user)
codespace, err := chooseCodespace(ctx, a.apiClient, user)
if err != nil {
if err == errNoCodespaces {
return err

View file

@ -12,14 +12,43 @@ import (
"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
"golang.org/x/term"
)
type App struct {
apiClient apiClient
logger *output.Logger
}
func NewApp(logger *output.Logger, apiClient apiClient) *App {
return &App{
apiClient: apiClient,
logger: logger,
}
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
type apiClient interface {
GetUser(ctx context.Context) (*api.User, error)
GetCodespaceToken(ctx context.Context, user, name string) (string, error)
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error)
DeleteCodespace(ctx context.Context, user, name string) error
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
AuthorizedKeys(ctx context.Context, user string) ([]byte, error)
GetCodespaceRegionLocation(ctx context.Context) (string, error)
GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error)
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
}
var errNoCodespaces = errors.New("you have no codespaces")
func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) {
func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) {
codespaces, err := apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return nil, fmt.Errorf("error getting codespaces: %w", err)
@ -68,7 +97,7 @@ func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) (
// getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty.
// It then fetches the codespace token and the codespace record.
func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) {
func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) {
if codespaceName == "" {
codespace, err = chooseCodespace(ctx, apiClient, user)
if err != nil {
@ -135,7 +164,7 @@ func ask(qs []*survey.Question, response interface{}) error {
// checkAuthorizedKeys reports an error if the user has not registered any SSH keys;
// see https://github.com/github/ghcs/issues/166#issuecomment-921769703.
// The check is not required for security but it improves the error message.
func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error {
func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error {
keys, err := client.AuthorizedKeys(ctx, user)
if err != nil {
return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err)

View file

@ -22,15 +22,15 @@ type createOptions struct {
showStatus bool
}
func newCreateCmd() *cobra.Command {
opts := &createOptions{}
func newCreateCmd(app *App) *cobra.Command {
opts := createOptions{}
createCmd := &cobra.Command{
Use: "create",
Short: "Create a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return create(opts)
return app.Create(cmd.Context(), opts)
},
}
@ -42,12 +42,10 @@ func newCreateCmd() *cobra.Command {
return createCmd
}
func create(opts *createOptions) error {
ctx := context.Background()
apiClient := api.New(GithubToken)
locationCh := getLocation(ctx, apiClient)
userCh := getUser(ctx, apiClient)
log := output.NewLogger(os.Stdout, os.Stderr, false)
// Create creates a new Codespace
func (a *App) Create(ctx context.Context, opts createOptions) error {
locationCh := getLocation(ctx, a.apiClient)
userCh := getUser(ctx, a.apiClient)
repo, err := getRepoName(opts.repo)
if err != nil {
@ -58,7 +56,7 @@ func create(opts *createOptions) error {
return fmt.Errorf("error getting branch name: %w", err)
}
repository, err := apiClient.GetRepository(ctx, repo)
repository, err := a.apiClient.GetRepository(ctx, repo)
if err != nil {
return fmt.Errorf("error getting repository: %w", err)
}
@ -73,7 +71,7 @@ func create(opts *createOptions) error {
return fmt.Errorf("error getting codespace user: %w", userResult.Err)
}
machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient)
machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient)
if err != nil {
return fmt.Errorf("error getting machine type: %w", err)
}
@ -81,26 +79,26 @@ func create(opts *createOptions) error {
return errors.New("there are no available machine types for this repository")
}
log.Print("Creating your codespace...")
codespace, err := apiClient.CreateCodespace(ctx, log, &api.CreateCodespaceParams{
a.logger.Print("Creating your codespace...")
codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{
User: userResult.User.Login,
RepositoryID: repository.ID,
Branch: branch,
Machine: machine,
Location: locationResult.Location,
})
log.Print("\n")
a.logger.Print("\n")
if err != nil {
return fmt.Errorf("error creating codespace: %w", err)
}
if opts.showStatus {
if err := showStatus(ctx, log, apiClient, userResult.User, codespace); err != nil {
if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil {
return fmt.Errorf("show status: %w", err)
}
}
log.Printf("Codespace created: ")
a.logger.Printf("Codespace created: ")
fmt.Fprintln(os.Stdout, codespace.Name)
@ -110,7 +108,7 @@ func create(opts *createOptions) error {
// showStatus polls the codespace for a list of post create states and their status. It will keep polling
// until all states have finished. Once all states have finished, we poll once more to check if any new
// states have been introduced and stop polling otherwise.
func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error {
func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error {
var lastState codespaces.PostCreateState
var breakNextState bool
@ -177,7 +175,7 @@ type getUserResult struct {
}
// getUser fetches the user record associated with the GITHUB_TOKEN
func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult {
func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult {
ch := make(chan getUserResult, 1)
go func() {
user, err := apiClient.GetUser(ctx)
@ -192,7 +190,7 @@ type locationResult struct {
}
// getLocation fetches the closest Codespace datacenter region/location to the user.
func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult {
func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult {
ch := make(chan locationResult, 1)
go func() {
location, err := apiClient.GetCodespaceRegionLocation(ctx)
@ -236,7 +234,7 @@ func getBranchName(branch string) (string, error) {
}
// getMachineName prompts the user to select the machine type, or validates the machine if non-empty.
func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) {
func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) {
skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location)
if err != nil {
return "", fmt.Errorf("error requesting machine instance types: %w", err)

View file

@ -4,12 +4,10 @@ import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/AlecAivazis/survey/v2"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
@ -24,7 +22,6 @@ type deleteOptions struct {
isInteractive bool
now func() time.Time
apiClient apiClient
prompter prompter
}
@ -33,20 +30,10 @@ type prompter interface {
Confirm(message string) (bool, error)
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
type apiClient interface {
GetUser(ctx context.Context) (*api.User, error)
GetCodespaceToken(ctx context.Context, user, name string) (string, error)
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error)
DeleteCodespace(ctx context.Context, user, name string) error
}
func newDeleteCmd() *cobra.Command {
func newDeleteCmd(app *App) *cobra.Command {
opts := deleteOptions{
isInteractive: hasTTY,
now: time.Now,
apiClient: api.New(os.Getenv("GITHUB_TOKEN")),
prompter: &surveyPrompter{},
}
@ -58,8 +45,7 @@ func newDeleteCmd() *cobra.Command {
if opts.deleteAll && opts.repoFilter != "" {
return errors.New("both --all and --repo is not supported")
}
log := output.NewLogger(cmd.OutOrStdout(), cmd.ErrOrStderr(), !opts.isInteractive)
return delete(context.Background(), log, opts)
return app.Delete(cmd.Context(), opts)
},
}
@ -72,12 +58,8 @@ func newDeleteCmd() *cobra.Command {
return deleteCmd
}
type logger interface {
Errorf(format string, v ...interface{}) (int, error)
}
func delete(ctx context.Context, log logger, opts deleteOptions) error {
user, err := opts.apiClient.GetUser(ctx)
func (a *App) Delete(ctx context.Context, opts deleteOptions) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
@ -85,7 +67,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error {
var codespaces []*api.Codespace
nameFilter := opts.codespaceName
if nameFilter == "" {
codespaces, err = opts.apiClient.ListCodespaces(ctx, user.Login)
codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return fmt.Errorf("error getting codespaces: %w", err)
}
@ -99,12 +81,12 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error {
}
} else {
// TODO: this token is discarded and then re-requested later in DeleteCodespace
token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter)
token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter)
if err != nil {
return fmt.Errorf("error getting codespace token: %w", err)
}
codespace, err := opts.apiClient.GetCodespace(ctx, token, user.Login, nameFilter)
codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter)
if err != nil {
return fmt.Errorf("error fetching codespace information: %w", err)
}
@ -150,8 +132,8 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error {
for _, c := range codespacesToDelete {
codespaceName := c.Name
g.Go(func() error {
if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil {
_, _ = log.Errorf("error deleting codespace %q: %v\n", codespaceName, err)
if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil {
_, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err)
return err
}
return nil

View file

@ -186,7 +186,6 @@ func TestDelete(t *testing.T) {
}
}
opts := tt.opts
opts.apiClient = apiMock
opts.now = func() time.Time { return now }
opts.prompter = &prompterMock{
ConfirmFunc: func(msg string) (bool, error) {
@ -200,8 +199,11 @@ func TestDelete(t *testing.T) {
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
log := output.NewLogger(stdout, stderr, false)
err := delete(context.Background(), log, opts)
app := &App{
apiClient: apiMock,
logger: output.NewLogger(stdout, stderr, false),
}
err := app.Delete(context.Background(), opts)
if (err != nil) != tt.wantErr {
t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr)
}

View file

@ -10,42 +10,35 @@ import (
"github.com/spf13/cobra"
)
type listOptions struct {
asJSON bool
}
func newListCmd() *cobra.Command {
opts := &listOptions{}
func newListCmd(app *App) *cobra.Command {
var asJSON bool
listCmd := &cobra.Command{
Use: "list",
Short: "List your codespaces",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return list(opts)
return app.List(cmd.Context(), asJSON)
},
}
listCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON")
listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
return listCmd
}
func list(opts *listOptions) error {
apiClient := api.New(GithubToken)
ctx := context.Background()
user, err := apiClient.GetUser(ctx)
func (a *App) List(ctx context.Context, asJSON bool) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespaces, err := apiClient.ListCodespaces(ctx, user.Login)
codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return fmt.Errorf("error getting codespaces: %w", err)
}
table := output.NewTable(os.Stdout, opts.asJSON)
table := output.NewTable(os.Stdout, asJSON)
table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"})
for _, codespace := range codespaces {
table.Append([]string{

View file

@ -4,29 +4,24 @@ import (
"context"
"fmt"
"net"
"os"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/ghcs/internal/liveshare"
"github.com/spf13/cobra"
)
func newLogsCmd() *cobra.Command {
func newLogsCmd(app *App) *cobra.Command {
var (
codespace string
follow bool
)
log := output.NewLogger(os.Stdout, os.Stderr, false)
logsCmd := &cobra.Command{
Use: "logs",
Short: "Access codespace logs",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return logs(context.Background(), log, codespace, follow)
return app.Logs(cmd.Context(), codespace, follow)
},
}
@ -36,29 +31,27 @@ func newLogsCmd() *cobra.Command {
return logsCmd
}
func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) {
func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) {
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
apiClient := api.New(GithubToken)
user, err := apiClient.GetUser(ctx)
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("getting user: %w", err)
}
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login)
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName)
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("connecting to Live Share: %w", err)
}
@ -76,7 +69,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow
defer listen.Close()
localPort := listen.Addr().(*net.TCPAddr).Port
log.Println("Fetching SSH Details...")
a.logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)

View file

@ -4,22 +4,45 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"github.com/github/ghcs/cmd/ghcs"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
)
func main() {
rootCmd := ghcs.NewRootCmd()
token := os.Getenv("GITHUB_TOKEN")
rootCmd := ghcs.NewRootCmd(ghcs.NewApp(
output.NewLogger(os.Stdout, os.Stderr, false),
api.New(token, http.DefaultClient),
))
// Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still
// function without the token set.
oldPreRun := rootCmd.PersistentPreRunE
rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
if token == "" {
return errTokenMissing
}
if oldPreRun != nil {
return oldPreRun(cmd, args)
}
return nil
}
if cmd, err := rootCmd.ExecuteC(); err != nil {
explainError(os.Stderr, err, cmd)
os.Exit(1)
}
}
var errTokenMissing = errors.New("GITHUB_TOKEN is missing")
func explainError(w io.Writer, err error, cmd *cobra.Command) {
if errors.Is(err, ghcs.ErrTokenMissing) {
if errors.Is(err, errTokenMissing) {
fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo")
fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.")
return

View file

@ -16,21 +16,42 @@ import (
//
// // make and configure a mocked apiClient
// mockedapiClient := &apiClientMock{
// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) {
// panic("mock out the AuthorizedKeys method")
// },
// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
// panic("mock out the CreateCodespace method")
// },
// DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error {
// panic("mock out the DeleteCodespace method")
// },
// GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
// panic("mock out the GetCodespace method")
// },
// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
// panic("mock out the GetCodespaceRegionLocation method")
// },
// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
// panic("mock out the GetCodespaceRepositoryContents method")
// },
// GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) {
// panic("mock out the GetCodespaceToken method")
// },
// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
// panic("mock out the GetCodespacesSKUs method")
// },
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
// panic("mock out the GetRepository method")
// },
// GetUserFunc: func(ctx context.Context) (*api.User, error) {
// panic("mock out the GetUser method")
// },
// ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) {
// panic("mock out the ListCodespaces method")
// },
// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error {
// panic("mock out the StartCodespace method")
// },
// }
//
// // use mockedapiClient in code that requires apiClient
@ -38,23 +59,58 @@ import (
//
// }
type apiClientMock struct {
// AuthorizedKeysFunc mocks the AuthorizedKeys method.
AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error)
// CreateCodespaceFunc mocks the CreateCodespace method.
CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
// DeleteCodespaceFunc mocks the DeleteCodespace method.
DeleteCodespaceFunc func(ctx context.Context, user string, name string) error
// GetCodespaceFunc mocks the GetCodespace method.
GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error)
// GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method.
GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error)
// GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method.
GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
// GetCodespaceTokenFunc mocks the GetCodespaceToken method.
GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error)
// GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method.
GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error)
// GetRepositoryFunc mocks the GetRepository method.
GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error)
// GetUserFunc mocks the GetUser method.
GetUserFunc func(ctx context.Context) (*api.User, error)
// ListCodespacesFunc mocks the ListCodespaces method.
ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error)
// StartCodespaceFunc mocks the StartCodespace method.
StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error
// calls tracks calls to the methods.
calls struct {
// AuthorizedKeys holds details about calls to the AuthorizedKeys method.
AuthorizedKeys []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
}
// CreateCodespace holds details about calls to the CreateCodespace method.
CreateCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Params is the params argument value.
Params *api.CreateCodespaceParams
}
// DeleteCodespace holds details about calls to the DeleteCodespace method.
DeleteCodespace []struct {
// Ctx is the ctx argument value.
@ -75,6 +131,20 @@ type apiClientMock struct {
// Name is the name argument value.
Name string
}
// GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method.
GetCodespaceRegionLocation []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method.
GetCodespaceRepositoryContents []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Codespace is the codespace argument value.
Codespace *api.Codespace
// Path is the path argument value.
Path string
}
// GetCodespaceToken holds details about calls to the GetCodespaceToken method.
GetCodespaceToken []struct {
// Ctx is the ctx argument value.
@ -84,6 +154,26 @@ type apiClientMock struct {
// Name is the name argument value.
Name string
}
// GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method.
GetCodespacesSKUs []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User *api.User
// Repository is the repository argument value.
Repository *api.Repository
// Branch is the branch argument value.
Branch string
// Location is the location argument value.
Location string
}
// GetRepository holds details about calls to the GetRepository method.
GetRepository []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Nwo is the nwo argument value.
Nwo string
}
// GetUser holds details about calls to the GetUser method.
GetUser []struct {
// Ctx is the ctx argument value.
@ -96,12 +186,98 @@ type apiClientMock struct {
// User is the user argument value.
User string
}
// StartCodespace holds details about calls to the StartCodespace method.
StartCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Token is the token argument value.
Token string
// Codespace is the codespace argument value.
Codespace *api.Codespace
}
}
lockDeleteCodespace sync.RWMutex
lockGetCodespace sync.RWMutex
lockGetCodespaceToken sync.RWMutex
lockGetUser sync.RWMutex
lockListCodespaces sync.RWMutex
lockAuthorizedKeys sync.RWMutex
lockCreateCodespace sync.RWMutex
lockDeleteCodespace sync.RWMutex
lockGetCodespace sync.RWMutex
lockGetCodespaceRegionLocation sync.RWMutex
lockGetCodespaceRepositoryContents sync.RWMutex
lockGetCodespaceToken sync.RWMutex
lockGetCodespacesSKUs sync.RWMutex
lockGetRepository sync.RWMutex
lockGetUser sync.RWMutex
lockListCodespaces sync.RWMutex
lockStartCodespace sync.RWMutex
}
// AuthorizedKeys calls AuthorizedKeysFunc.
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
if mock.AuthorizedKeysFunc == nil {
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
}
callInfo := struct {
Ctx context.Context
User string
}{
Ctx: ctx,
User: user,
}
mock.lockAuthorizedKeys.Lock()
mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo)
mock.lockAuthorizedKeys.Unlock()
return mock.AuthorizedKeysFunc(ctx, user)
}
// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys.
// Check the length with:
// len(mockedapiClient.AuthorizedKeysCalls())
func (mock *apiClientMock) AuthorizedKeysCalls() []struct {
Ctx context.Context
User string
} {
var calls []struct {
Ctx context.Context
User string
}
mock.lockAuthorizedKeys.RLock()
calls = mock.calls.AuthorizedKeys
mock.lockAuthorizedKeys.RUnlock()
return calls
}
// CreateCodespace calls CreateCodespaceFunc.
func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
if mock.CreateCodespaceFunc == nil {
panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Params *api.CreateCodespaceParams
}{
Ctx: ctx,
Params: params,
}
mock.lockCreateCodespace.Lock()
mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo)
mock.lockCreateCodespace.Unlock()
return mock.CreateCodespaceFunc(ctx, params)
}
// CreateCodespaceCalls gets all the calls that were made to CreateCodespace.
// Check the length with:
// len(mockedapiClient.CreateCodespaceCalls())
func (mock *apiClientMock) CreateCodespaceCalls() []struct {
Ctx context.Context
Params *api.CreateCodespaceParams
} {
var calls []struct {
Ctx context.Context
Params *api.CreateCodespaceParams
}
mock.lockCreateCodespace.RLock()
calls = mock.calls.CreateCodespace
mock.lockCreateCodespace.RUnlock()
return calls
}
// DeleteCodespace calls DeleteCodespaceFunc.
@ -186,6 +362,76 @@ func (mock *apiClientMock) GetCodespaceCalls() []struct {
return calls
}
// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc.
func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
if mock.GetCodespaceRegionLocationFunc == nil {
panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockGetCodespaceRegionLocation.Lock()
mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo)
mock.lockGetCodespaceRegionLocation.Unlock()
return mock.GetCodespaceRegionLocationFunc(ctx)
}
// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation.
// Check the length with:
// len(mockedapiClient.GetCodespaceRegionLocationCalls())
func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockGetCodespaceRegionLocation.RLock()
calls = mock.calls.GetCodespaceRegionLocation
mock.lockGetCodespaceRegionLocation.RUnlock()
return calls
}
// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc.
func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
if mock.GetCodespaceRepositoryContentsFunc == nil {
panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called")
}
callInfo := struct {
Ctx context.Context
Codespace *api.Codespace
Path string
}{
Ctx: ctx,
Codespace: codespace,
Path: path,
}
mock.lockGetCodespaceRepositoryContents.Lock()
mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo)
mock.lockGetCodespaceRepositoryContents.Unlock()
return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path)
}
// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents.
// Check the length with:
// len(mockedapiClient.GetCodespaceRepositoryContentsCalls())
func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct {
Ctx context.Context
Codespace *api.Codespace
Path string
} {
var calls []struct {
Ctx context.Context
Codespace *api.Codespace
Path string
}
mock.lockGetCodespaceRepositoryContents.RLock()
calls = mock.calls.GetCodespaceRepositoryContents
mock.lockGetCodespaceRepositoryContents.RUnlock()
return calls
}
// GetCodespaceToken calls GetCodespaceTokenFunc.
func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) {
if mock.GetCodespaceTokenFunc == nil {
@ -225,6 +471,88 @@ func (mock *apiClientMock) GetCodespaceTokenCalls() []struct {
return calls
}
// GetCodespacesSKUs calls GetCodespacesSKUsFunc.
func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
if mock.GetCodespacesSKUsFunc == nil {
panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called")
}
callInfo := struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
}{
Ctx: ctx,
User: user,
Repository: repository,
Branch: branch,
Location: location,
}
mock.lockGetCodespacesSKUs.Lock()
mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo)
mock.lockGetCodespacesSKUs.Unlock()
return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location)
}
// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs.
// Check the length with:
// len(mockedapiClient.GetCodespacesSKUsCalls())
func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
} {
var calls []struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
}
mock.lockGetCodespacesSKUs.RLock()
calls = mock.calls.GetCodespacesSKUs
mock.lockGetCodespacesSKUs.RUnlock()
return calls
}
// GetRepository calls GetRepositoryFunc.
func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) {
if mock.GetRepositoryFunc == nil {
panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called")
}
callInfo := struct {
Ctx context.Context
Nwo string
}{
Ctx: ctx,
Nwo: nwo,
}
mock.lockGetRepository.Lock()
mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo)
mock.lockGetRepository.Unlock()
return mock.GetRepositoryFunc(ctx, nwo)
}
// GetRepositoryCalls gets all the calls that were made to GetRepository.
// Check the length with:
// len(mockedapiClient.GetRepositoryCalls())
func (mock *apiClientMock) GetRepositoryCalls() []struct {
Ctx context.Context
Nwo string
} {
var calls []struct {
Ctx context.Context
Nwo string
}
mock.lockGetRepository.RLock()
calls = mock.calls.GetRepository
mock.lockGetRepository.RUnlock()
return calls
}
// GetUser calls GetUserFunc.
func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) {
if mock.GetUserFunc == nil {
@ -290,3 +618,42 @@ func (mock *apiClientMock) ListCodespacesCalls() []struct {
mock.lockListCodespaces.RUnlock()
return calls
}
// StartCodespace calls StartCodespaceFunc.
func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error {
if mock.StartCodespaceFunc == nil {
panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Token string
Codespace *api.Codespace
}{
Ctx: ctx,
Token: token,
Codespace: codespace,
}
mock.lockStartCodespace.Lock()
mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo)
mock.lockStartCodespace.Unlock()
return mock.StartCodespaceFunc(ctx, token, codespace)
}
// StartCodespaceCalls gets all the calls that were made to StartCodespace.
// Check the length with:
// len(mockedapiClient.StartCodespaceCalls())
func (mock *apiClientMock) StartCodespaceCalls() []struct {
Ctx context.Context
Token string
Codespace *api.Codespace
} {
var calls []struct {
Ctx context.Context
Token string
Codespace *api.Codespace
}
mock.lockStartCodespace.RLock()
calls = mock.calls.StartCodespace
mock.lockStartCodespace.RUnlock()
return calls
}

View file

@ -22,7 +22,7 @@ import (
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
// according to the specified flags.
func newPortsCmd() *cobra.Command {
func newPortsCmd(app *App) *cobra.Command {
var (
codespace string
asJSON bool
@ -33,31 +33,28 @@ func newPortsCmd() *cobra.Command {
Short: "List ports in a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return ports(codespace, asJSON)
return app.ListPorts(cmd.Context(), codespace, asJSON)
},
}
portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
portsCmd.AddCommand(newPortsPublicCmd())
portsCmd.AddCommand(newPortsPrivateCmd())
portsCmd.AddCommand(newPortsForwardCmd())
portsCmd.AddCommand(newPortsPublicCmd(app))
portsCmd.AddCommand(newPortsPrivateCmd(app))
portsCmd.AddCommand(newPortsForwardCmd(app))
return portsCmd
}
func ports(codespaceName string, asJSON bool) (err error) {
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
ctx := context.Background()
log := output.NewLogger(os.Stdout, os.Stderr, asJSON)
user, err := apiClient.GetUser(ctx)
// ListPorts lists known ports in a codespace.
func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName)
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
// TODO(josebalius): remove special handling of this error here and it other places
if err == errNoCodespaces {
@ -66,15 +63,15 @@ func ports(codespaceName string, asJSON bool) (err error) {
return fmt.Errorf("error choosing codespace: %w", err)
}
devContainerCh := getDevContainer(ctx, apiClient, codespace)
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
log.Println("Loading ports...")
a.logger.Println("Loading ports...")
ports, err := session.GetSharedServers(ctx)
if err != nil {
return fmt.Errorf("error getting ports of shared servers: %w", err)
@ -83,7 +80,7 @@ func ports(codespaceName string, asJSON bool) (err error) {
devContainerResult := <-devContainerCh
if devContainerResult.err != nil {
// Warn about failure to read the devcontainer file. Not a ghcs command error.
_, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error())
_, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error())
}
table := output.NewTable(os.Stdout, asJSON)
@ -122,7 +119,7 @@ type portAttribute struct {
Label string `json:"label"`
}
func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Codespace) <-chan devContainerResult {
func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult {
ch := make(chan devContainerResult, 1)
go func() {
contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json")
@ -154,7 +151,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod
}
// newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public.
func newPortsPublicCmd() *cobra.Command {
func newPortsPublicCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "public <port>",
Short: "Mark port as public",
@ -168,14 +165,13 @@ func newPortsPublicCmd() *cobra.Command {
return fmt.Errorf("get codespace flag: %w", err)
}
log := output.NewLogger(os.Stdout, os.Stderr, false)
return updatePortVisibility(log, codespace, args[0], true)
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true)
},
}
}
// newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private.
func newPortsPrivateCmd() *cobra.Command {
func newPortsPrivateCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "private <port>",
Short: "Mark port as private",
@ -189,22 +185,18 @@ func newPortsPrivateCmd() *cobra.Command {
return fmt.Errorf("get codespace flag: %w", err)
}
log := output.NewLogger(os.Stdout, os.Stderr, false)
return updatePortVisibility(log, codespace, args[0], false)
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false)
},
}
}
func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) {
ctx := context.Background()
apiClient := api.New(GithubToken)
user, err := apiClient.GetUser(ctx)
func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName)
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
if err == errNoCodespaces {
return err
@ -212,7 +204,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string,
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -231,14 +223,14 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string,
if !public {
state = "PRIVATE"
}
log.Printf("Port %s is now %s.\n", sourcePort, state)
a.logger.Printf("Port %s is now %s.\n", sourcePort, state)
return nil
}
// NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of
// port pairs from the codespace to localhost.
func newPortsForwardCmd() *cobra.Command {
func newPortsForwardCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "forward <remote-port>:<local-port>...",
Short: "Forward ports",
@ -252,27 +244,23 @@ func newPortsForwardCmd() *cobra.Command {
return fmt.Errorf("get codespace flag: %w", err)
}
log := output.NewLogger(os.Stdout, os.Stderr, false)
return forwardPorts(log, codespace, args)
return app.ForwardPorts(cmd.Context(), codespace, args)
},
}
}
func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) {
ctx := context.Background()
apiClient := api.New(GithubToken)
func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) {
portPairs, err := getPortPairs(ports)
if err != nil {
return fmt.Errorf("get port pairs: %w", err)
}
user, err := apiClient.GetUser(ctx)
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName)
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
if err == errNoCodespaces {
return err
@ -280,7 +268,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -297,7 +285,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err
return err
}
defer listen.Close()
log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
return fwd.ForwardToListener(ctx, listen) // error always non-nil

View file

@ -1,10 +1,8 @@
package ghcs
import (
"errors"
"fmt"
"log"
"os"
"strconv"
"strings"
@ -15,10 +13,7 @@ import (
var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number.
// GithubToken is a temporary stopgap to make the token configurable by apps that import this package
var GithubToken = os.Getenv("GITHUB_TOKEN")
func NewRootCmd() *cobra.Command {
func NewRootCmd(app *App) *cobra.Command {
var lightstep string
root := &cobra.Command{
@ -32,28 +27,23 @@ token to access the GitHub API with.`,
Version: version,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
if os.Getenv("GITHUB_TOKEN") == "" {
return ErrTokenMissing
}
return initLightstep(lightstep)
},
}
root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)")
root.AddCommand(newCodeCmd())
root.AddCommand(newCreateCmd())
root.AddCommand(newDeleteCmd())
root.AddCommand(newListCmd())
root.AddCommand(newLogsCmd())
root.AddCommand(newPortsCmd())
root.AddCommand(newSSHCmd())
root.AddCommand(newCodeCmd(app))
root.AddCommand(newCreateCmd(app))
root.AddCommand(newDeleteCmd(app))
root.AddCommand(newListCmd(app))
root.AddCommand(newLogsCmd(app))
root.AddCommand(newPortsCmd(app))
root.AddCommand(newSSHCmd(app))
return root
}
var ErrTokenMissing = errors.New("GITHUB_TOKEN is missing")
// initLightstep parses the --lightstep=service:token@host:port flag and
// enables tracing if non-empty.
func initLightstep(config string) error {

View file

@ -4,16 +4,13 @@ import (
"context"
"fmt"
"net"
"os"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/ghcs/internal/liveshare"
"github.com/spf13/cobra"
)
func newSSHCmd() *cobra.Command {
func newSSHCmd(app *App) *cobra.Command {
var sshProfile, codespaceName string
var sshServerPort int
@ -21,7 +18,7 @@ func newSSHCmd() *cobra.Command {
Use: "ssh [flags] [--] [ssh-flags] [command]",
Short: "SSH into a codespace",
RunE: func(cmd *cobra.Command, args []string) error {
return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort)
return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort)
},
}
@ -32,30 +29,28 @@ func newSSHCmd() *cobra.Command {
return sshCmd
}
func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
// SSH opens an ssh session or runs an ssh command in a codespace.
func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
// Ensure all child tasks (e.g. port forwarding) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
apiClient := api.New(GithubToken)
log := output.NewLogger(os.Stdout, os.Stderr, false)
user, err := apiClient.GetUser(ctx)
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login)
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName)
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -65,7 +60,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string
return err
}
log.Println("Fetching SSH Details...")
a.logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
@ -86,7 +81,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
}
log.Println("Ready...")
a.logger.Println("Ready...")
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
@ -95,7 +90,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string
shellClosed := make(chan error, 1)
go func() {
shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort)
shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort)
}()
select {

View file

@ -45,14 +45,18 @@ const githubAPI = "https://api.github.com"
type API struct {
token string
client *http.Client
client httpClient
githubAPI string
}
func New(token string) *API {
type httpClient interface {
Do(req *http.Request) (*http.Response, error)
}
func New(token string, httpClient httpClient) *API {
return &API{
token: token,
client: &http.Client{},
client: httpClient,
githubAPI: githubAPI,
}
}
@ -272,6 +276,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string)
return nil, fmt.Errorf("error creating request: %w", err)
}
// TODO: use a.setHeaders()
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*")
if err != nil {
@ -306,6 +311,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes
return fmt.Errorf("error creating request: %w", err)
}
// TODO: use a.setHeaders()
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start")
if err != nil {
@ -420,14 +426,9 @@ type CreateCodespaceParams struct {
Branch, Machine, Location string
}
type logger interface {
Print(v ...interface{}) (int, error)
Println(v ...interface{}) (int, error)
}
// CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it
// fails to create.
func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCodespaceParams) (*Codespace, error) {
func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) {
codespace, err := a.startCreate(
ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location,
)
@ -449,7 +450,6 @@ func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCod
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
log.Print(".")
token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name)
if err != nil {
if err == ErrNotProvisioned {
@ -532,6 +532,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName st
return fmt.Errorf("error creating request: %w", err)
}
// TODO: use a.setHeaders()
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*")
if err != nil {
@ -631,6 +632,8 @@ func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http
}
func (a *API) setHeaders(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+a.token)
if a.token != "" {
req.Header.Set("Authorization", "Bearer "+a.token)
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
}

View file

@ -23,9 +23,15 @@ func connectionReady(codespace *api.Codespace) bool {
codespace.Environment.State == api.CodespaceEnvironmentStateAvailable
}
type apiClient interface {
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
GetCodespaceToken(ctx context.Context, user, codespace string) (string, error)
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
}
// ConnectToLiveshare waits for a Codespace to become running,
// and connects to it using a Live Share session.
func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) {
func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) {
var startedCodespace bool
if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable {
startedCodespace = true

View file

@ -36,7 +36,7 @@ type PostCreateState struct {
// PollPostCreateStates watches for state changes in a codespace,
// and calls the supplied poller for each batch of state changes.
// It runs until it encounters an error, including cancellation of the context.
func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name)
if err != nil {
return fmt.Errorf("getting codespace token: %w", err)