Merge pull request #4384 from cli/import-codespaces

Import codespaces
This commit is contained in:
Jose Garcia 2021-09-30 09:44:29 -04:00 committed by GitHub
commit 877ad22da6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 5227 additions and 14 deletions

65
cmd/ghcs/code.go Normal file
View file

@ -0,0 +1,65 @@
package ghcs
import (
"context"
"fmt"
"net/url"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
)
func newCodeCmd(app *App) *cobra.Command {
var (
codespace string
useInsiders bool
)
codeCmd := &cobra.Command{
Use: "code",
Short: "Open a codespace in VS Code",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.VSCode(cmd.Context(), codespace, useInsiders)
},
}
codeCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of VS Code")
return codeCmd
}
// VSCode opens a codespace in the local VS VSCode application.
func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
if codespaceName == "" {
codespace, err := chooseCodespace(ctx, a.apiClient, user)
if err != nil {
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error choosing codespace: %w", err)
}
codespaceName = codespace.Name
}
url := vscodeProtocolURL(codespaceName, useInsiders)
if err := open.Run(url); err != nil {
return fmt.Errorf("error opening vscode URL %s: %s. (Is VS Code installed?)", url, err)
}
return nil
}
func vscodeProtocolURL(codespaceName string, useInsiders bool) string {
application := "vscode"
if useInsiders {
application = "vscode-insiders"
}
return fmt.Sprintf("%s://github.codespaces/connect?name=%s", application, url.QueryEscape(codespaceName))
}

185
cmd/ghcs/common.go Normal file
View file

@ -0,0 +1,185 @@
package ghcs
// This file defines functions common to the entire ghcs command set.
import (
"context"
"errors"
"fmt"
"io"
"os"
"sort"
"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/cli/cli/v2/cmd/ghcs/output"
"github.com/cli/cli/v2/internal/api"
"github.com/spf13/cobra"
"golang.org/x/term"
)
type App struct {
apiClient apiClient
logger *output.Logger
}
func NewApp(logger *output.Logger, apiClient apiClient) *App {
return &App{
apiClient: apiClient,
logger: logger,
}
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
type apiClient interface {
GetUser(ctx context.Context) (*api.User, error)
GetCodespaceToken(ctx context.Context, user, name string) (string, error)
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error)
DeleteCodespace(ctx context.Context, user, name string) error
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
AuthorizedKeys(ctx context.Context, user string) ([]byte, error)
GetCodespaceRegionLocation(ctx context.Context) (string, error)
GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error)
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
}
var errNoCodespaces = errors.New("you have no codespaces")
func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) {
codespaces, err := apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return nil, fmt.Errorf("error getting codespaces: %w", err)
}
return chooseCodespaceFromList(ctx, codespaces)
}
func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) (*api.Codespace, error) {
if len(codespaces) == 0 {
return nil, errNoCodespaces
}
sort.Slice(codespaces, func(i, j int) bool {
return codespaces[i].CreatedAt > codespaces[j].CreatedAt
})
codespacesByName := make(map[string]*api.Codespace)
codespacesNames := make([]string, 0, len(codespaces))
for _, codespace := range codespaces {
codespacesByName[codespace.Name] = codespace
codespacesNames = append(codespacesNames, codespace.Name)
}
sshSurvey := []*survey.Question{
{
Name: "codespace",
Prompt: &survey.Select{
Message: "Choose codespace:",
Options: codespacesNames,
Default: codespacesNames[0],
},
Validate: survey.Required,
},
}
var answers struct {
Codespace string
}
if err := ask(sshSurvey, &answers); err != nil {
return nil, fmt.Errorf("error getting answers: %w", err)
}
codespace := codespacesByName[answers.Codespace]
return codespace, nil
}
// getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty.
// It then fetches the codespace token and the codespace record.
func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) {
if codespaceName == "" {
codespace, err = chooseCodespace(ctx, apiClient, user)
if err != nil {
if err == errNoCodespaces {
return nil, "", err
}
return nil, "", fmt.Errorf("choosing codespace: %w", err)
}
codespaceName = codespace.Name
token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName)
if err != nil {
return nil, "", fmt.Errorf("getting codespace token: %w", err)
}
} else {
token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName)
if err != nil {
return nil, "", fmt.Errorf("getting codespace token for given codespace: %w", err)
}
codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName)
if err != nil {
return nil, "", fmt.Errorf("getting full codespace details: %w", err)
}
}
return codespace, token, nil
}
func safeClose(closer io.Closer, err *error) {
if closeErr := closer.Close(); *err == nil {
*err = closeErr
}
}
// hasTTY indicates whether the process connected to a terminal.
// It is not portable to assume stdin/stdout are fds 0 and 1.
var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
// ask asks survey questions on the terminal, using standard options.
// It fails unless hasTTY, but ideally callers should avoid calling it in that case.
func ask(qs []*survey.Question, response interface{}) error {
if !hasTTY {
return fmt.Errorf("no terminal")
}
err := survey.Ask(qs, response, survey.WithShowCursor(true))
// The survey package temporarily clears the terminal's ISIG mode bit
// (see tcsetattr(3)) so the QUIT button (Ctrl-C) is reported as
// ASCII \x03 (ETX) instead of delivering SIGINT to the application.
// So we have to serve ourselves the SIGINT.
//
// https://github.com/AlecAivazis/survey/#why-isnt-ctrl-c-working
if err == terminal.InterruptErr {
self, _ := os.FindProcess(os.Getpid())
_ = self.Signal(os.Interrupt) // assumes POSIX
// Suspend the goroutine, to avoid a race between
// return from main and async delivery of INT signal.
select {}
}
return err
}
// checkAuthorizedKeys reports an error if the user has not registered any SSH keys;
// see https://github.com/cli/cli/v2/issues/166#issuecomment-921769703.
// The check is not required for security but it improves the error message.
func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error {
keys, err := client.AuthorizedKeys(ctx, user)
if err != nil {
return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err)
}
if len(keys) == 0 {
return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user)
}
return nil // success
}
var ErrTooManyArgs = errors.New("the command accepts no arguments")
func noArgsConstraint(cmd *cobra.Command, args []string) error {
if len(args) > 0 {
return ErrTooManyArgs
}
return nil
}

297
cmd/ghcs/create.go Normal file
View file

@ -0,0 +1,297 @@
package ghcs
import (
"context"
"errors"
"fmt"
"os"
"strings"
"github.com/AlecAivazis/survey/v2"
"github.com/cli/cli/v2/cmd/ghcs/output"
"github.com/cli/cli/v2/internal/api"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/fatih/camelcase"
"github.com/spf13/cobra"
)
type createOptions struct {
repo string
branch string
machine string
showStatus bool
}
func newCreateCmd(app *App) *cobra.Command {
opts := createOptions{}
createCmd := &cobra.Command{
Use: "create",
Short: "Create a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.Create(cmd.Context(), opts)
},
}
createCmd.Flags().StringVarP(&opts.repo, "repo", "r", "", "repository name with owner: user/repo")
createCmd.Flags().StringVarP(&opts.branch, "branch", "b", "", "repository branch")
createCmd.Flags().StringVarP(&opts.machine, "machine", "m", "", "hardware specifications for the VM")
createCmd.Flags().BoolVarP(&opts.showStatus, "status", "s", false, "show status of post-create command and dotfiles")
return createCmd
}
// Create creates a new Codespace
func (a *App) Create(ctx context.Context, opts createOptions) error {
locationCh := getLocation(ctx, a.apiClient)
userCh := getUser(ctx, a.apiClient)
repo, err := getRepoName(opts.repo)
if err != nil {
return fmt.Errorf("error getting repository name: %w", err)
}
branch, err := getBranchName(opts.branch)
if err != nil {
return fmt.Errorf("error getting branch name: %w", err)
}
repository, err := a.apiClient.GetRepository(ctx, repo)
if err != nil {
return fmt.Errorf("error getting repository: %w", err)
}
locationResult := <-locationCh
if locationResult.Err != nil {
return fmt.Errorf("error getting codespace region location: %w", locationResult.Err)
}
userResult := <-userCh
if userResult.Err != nil {
return fmt.Errorf("error getting codespace user: %w", userResult.Err)
}
machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient)
if err != nil {
return fmt.Errorf("error getting machine type: %w", err)
}
if machine == "" {
return errors.New("there are no available machine types for this repository")
}
a.logger.Print("Creating your codespace...")
codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{
User: userResult.User.Login,
RepositoryID: repository.ID,
Branch: branch,
Machine: machine,
Location: locationResult.Location,
})
a.logger.Print("\n")
if err != nil {
return fmt.Errorf("error creating codespace: %w", err)
}
if opts.showStatus {
if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil {
return fmt.Errorf("show status: %w", err)
}
}
a.logger.Printf("Codespace created: ")
fmt.Fprintln(os.Stdout, codespace.Name)
return nil
}
// showStatus polls the codespace for a list of post create states and their status. It will keep polling
// until all states have finished. Once all states have finished, we poll once more to check if any new
// states have been introduced and stop polling otherwise.
func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error {
var lastState codespaces.PostCreateState
var breakNextState bool
finishedStates := make(map[string]bool)
ctx, stopPolling := context.WithCancel(ctx)
defer stopPolling()
poller := func(states []codespaces.PostCreateState) {
var inProgress bool
for _, state := range states {
if _, found := finishedStates[state.Name]; found {
continue // skip this state as we've processed it already
}
if state.Name != lastState.Name {
log.Print(state.Name)
if state.Status == codespaces.PostCreateStateRunning {
inProgress = true
lastState = state
log.Print("...")
break
}
finishedStates[state.Name] = true
log.Println("..." + state.Status)
} else {
if state.Status == codespaces.PostCreateStateRunning {
inProgress = true
log.Print(".")
break
}
finishedStates[state.Name] = true
log.Println(state.Status)
lastState = codespaces.PostCreateState{} // reset the value
}
}
if !inProgress {
if breakNextState {
stopPolling()
return
}
breakNextState = true
}
}
err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller)
if err != nil {
if errors.Is(err, context.Canceled) && breakNextState {
return nil // we cancelled the context to stop polling, we can ignore the error
}
return fmt.Errorf("failed to poll state changes from codespace: %w", err)
}
return nil
}
type getUserResult struct {
User *api.User
Err error
}
// getUser fetches the user record associated with the GITHUB_TOKEN
func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult {
ch := make(chan getUserResult, 1)
go func() {
user, err := apiClient.GetUser(ctx)
ch <- getUserResult{user, err}
}()
return ch
}
type locationResult struct {
Location string
Err error
}
// getLocation fetches the closest Codespace datacenter region/location to the user.
func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult {
ch := make(chan locationResult, 1)
go func() {
location, err := apiClient.GetCodespaceRegionLocation(ctx)
ch <- locationResult{location, err}
}()
return ch
}
// getRepoName prompts the user for the name of the repository, or returns the repository if non-empty.
func getRepoName(repo string) (string, error) {
if repo != "" {
return repo, nil
}
repoSurvey := []*survey.Question{
{
Name: "repository",
Prompt: &survey.Input{Message: "Repository:"},
Validate: survey.Required,
},
}
err := ask(repoSurvey, &repo)
return repo, err
}
// getBranchName prompts the user for the name of the branch, or returns the branch if non-empty.
func getBranchName(branch string) (string, error) {
if branch != "" {
return branch, nil
}
branchSurvey := []*survey.Question{
{
Name: "branch",
Prompt: &survey.Input{Message: "Branch:"},
Validate: survey.Required,
},
}
err := ask(branchSurvey, &branch)
return branch, err
}
// getMachineName prompts the user to select the machine type, or validates the machine if non-empty.
func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) {
skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location)
if err != nil {
return "", fmt.Errorf("error requesting machine instance types: %w", err)
}
// if user supplied a machine type, it must be valid
// if no machine type was supplied, we don't error if there are no machine types for the current repo
if machine != "" {
for _, sku := range skus {
if machine == sku.Name {
return machine, nil
}
}
availableSKUs := make([]string, len(skus))
for i := 0; i < len(skus); i++ {
availableSKUs[i] = skus[i].Name
}
return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSKUs)
} else if len(skus) == 0 {
return "", nil
}
if len(skus) == 1 {
return skus[0].Name, nil // VS Code does not prompt for SKU if there is only one, this makes us consistent with that behavior
}
skuNames := make([]string, 0, len(skus))
skuByName := make(map[string]*api.SKU)
for _, sku := range skus {
nameParts := camelcase.Split(sku.Name)
machineName := strings.Title(strings.ToLower(nameParts[0]))
skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName)
skuNames = append(skuNames, skuName)
skuByName[skuName] = sku
}
skuSurvey := []*survey.Question{
{
Name: "sku",
Prompt: &survey.Select{
Message: "Choose Machine Type:",
Options: skuNames,
Default: skuNames[0],
},
Validate: survey.Required,
},
}
var skuAnswers struct{ SKU string }
if err := ask(skuSurvey, &skuAnswers); err != nil {
return "", fmt.Errorf("error getting SKU: %w", err)
}
sku := skuByName[skuAnswers.SKU]
machine = sku.Name
return machine, nil
}

180
cmd/ghcs/delete.go Normal file
View file

@ -0,0 +1,180 @@
package ghcs
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/AlecAivazis/survey/v2"
"github.com/cli/cli/v2/internal/api"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
type deleteOptions struct {
deleteAll bool
skipConfirm bool
codespaceName string
repoFilter string
keepDays uint16
isInteractive bool
now func() time.Time
prompter prompter
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_prompter.go . prompter
type prompter interface {
Confirm(message string) (bool, error)
}
func newDeleteCmd(app *App) *cobra.Command {
opts := deleteOptions{
isInteractive: hasTTY,
now: time.Now,
prompter: &surveyPrompter{},
}
deleteCmd := &cobra.Command{
Use: "delete",
Short: "Delete a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
if opts.deleteAll && opts.repoFilter != "" {
return errors.New("both --all and --repo is not supported")
}
return app.Delete(cmd.Context(), opts)
},
}
deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace")
deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces")
deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a `repository`")
deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes")
deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days")
return deleteCmd
}
func (a *App) Delete(ctx context.Context, opts deleteOptions) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
var codespaces []*api.Codespace
nameFilter := opts.codespaceName
if nameFilter == "" {
codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return fmt.Errorf("error getting codespaces: %w", err)
}
if !opts.deleteAll && opts.repoFilter == "" {
c, err := chooseCodespaceFromList(ctx, codespaces)
if err != nil {
return fmt.Errorf("error choosing codespace: %w", err)
}
nameFilter = c.Name
}
} else {
// TODO: this token is discarded and then re-requested later in DeleteCodespace
token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter)
if err != nil {
return fmt.Errorf("error getting codespace token: %w", err)
}
codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter)
if err != nil {
return fmt.Errorf("error fetching codespace information: %w", err)
}
codespaces = []*api.Codespace{codespace}
}
codespacesToDelete := make([]*api.Codespace, 0, len(codespaces))
lastUpdatedCutoffTime := opts.now().AddDate(0, 0, -int(opts.keepDays))
for _, c := range codespaces {
if nameFilter != "" && c.Name != nameFilter {
continue
}
if opts.repoFilter != "" && !strings.EqualFold(c.RepositoryNWO, opts.repoFilter) {
continue
}
if opts.keepDays > 0 {
t, err := time.Parse(time.RFC3339, c.LastUsedAt)
if err != nil {
return fmt.Errorf("error parsing last_used_at timestamp %q: %w", c.LastUsedAt, err)
}
if t.After(lastUpdatedCutoffTime) {
continue
}
}
if !opts.skipConfirm {
confirmed, err := confirmDeletion(opts.prompter, c, opts.isInteractive)
if err != nil {
return fmt.Errorf("unable to confirm: %w", err)
}
if !confirmed {
continue
}
}
codespacesToDelete = append(codespacesToDelete, c)
}
if len(codespacesToDelete) == 0 {
return errors.New("no codespaces to delete")
}
g := errgroup.Group{}
for _, c := range codespacesToDelete {
codespaceName := c.Name
g.Go(func() error {
if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil {
_, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err)
return err
}
return nil
})
}
if err := g.Wait(); err != nil {
return errors.New("some codespaces failed to delete")
}
return nil
}
func confirmDeletion(p prompter, codespace *api.Codespace, isInteractive bool) (bool, error) {
gs := codespace.Environment.GitStatus
hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges
if !hasUnsavedChanges {
return true, nil
}
if !isInteractive {
return false, fmt.Errorf("codespace %s has unsaved changes (use --force to override)", codespace.Name)
}
return p.Confirm(fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name))
}
type surveyPrompter struct{}
func (p *surveyPrompter) Confirm(message string) (bool, error) {
var confirmed struct {
Confirmed bool
}
q := []*survey.Question{
{
Name: "confirmed",
Prompt: &survey.Confirm{
Message: message,
},
},
}
if err := ask(q, &confirmed); err != nil {
return false, fmt.Errorf("failed to prompt: %w", err)
}
return confirmed.Confirmed, nil
}

257
cmd/ghcs/delete_test.go Normal file
View file

@ -0,0 +1,257 @@
package ghcs
import (
"bytes"
"context"
"errors"
"fmt"
"sort"
"strings"
"testing"
"time"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/cmd/ghcs/output"
"github.com/cli/cli/v2/internal/api"
)
func TestDelete(t *testing.T) {
user := &api.User{Login: "hubot"}
now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z")
daysAgo := func(n int) string {
return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339)
}
tests := []struct {
name string
opts deleteOptions
codespaces []*api.Codespace
confirms map[string]bool
deleteErr error
wantErr bool
wantDeleted []string
wantStdout string
wantStderr string
}{
{
name: "by name",
opts: deleteOptions{
codespaceName: "hubot-robawt-abc",
},
codespaces: []*api.Codespace{
{
Name: "hubot-robawt-abc",
},
},
wantDeleted: []string{"hubot-robawt-abc"},
},
{
name: "by repo",
opts: deleteOptions{
repoFilter: "monalisa/spoon-knife",
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
RepositoryNWO: "monalisa/Spoon-Knife",
},
{
Name: "hubot-robawt-abc",
RepositoryNWO: "hubot/ROBAWT",
},
{
Name: "monalisa-spoonknife-c4f3",
RepositoryNWO: "monalisa/Spoon-Knife",
},
},
wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"},
},
{
name: "unused",
opts: deleteOptions{
deleteAll: true,
keepDays: 3,
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
LastUsedAt: daysAgo(1),
},
{
Name: "hubot-robawt-abc",
LastUsedAt: daysAgo(4),
},
{
Name: "monalisa-spoonknife-c4f3",
LastUsedAt: daysAgo(10),
},
},
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
},
{
name: "deletion failed",
opts: deleteOptions{
deleteAll: true,
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
},
{
Name: "hubot-robawt-abc",
},
},
deleteErr: errors.New("aborted by test"),
wantErr: true,
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-123"},
wantStderr: heredoc.Doc(`
error deleting codespace "hubot-robawt-abc": aborted by test
error deleting codespace "monalisa-spoonknife-123": aborted by test
`),
},
{
name: "with confirm",
opts: deleteOptions{
isInteractive: true,
deleteAll: true,
skipConfirm: false,
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
Environment: api.CodespaceEnvironment{
GitStatus: api.CodespaceEnvironmentGitStatus{
HasUnpushedChanges: true,
},
},
},
{
Name: "hubot-robawt-abc",
Environment: api.CodespaceEnvironment{
GitStatus: api.CodespaceEnvironmentGitStatus{
HasUncommitedChanges: true,
},
},
},
{
Name: "monalisa-spoonknife-c4f3",
Environment: api.CodespaceEnvironment{
GitStatus: api.CodespaceEnvironmentGitStatus{
HasUnpushedChanges: false,
HasUncommitedChanges: false,
},
},
},
},
confirms: map[string]bool{
"Codespace monalisa-spoonknife-123 has unsaved changes. OK to delete?": false,
"Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true,
},
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
apiMock := &apiClientMock{
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
DeleteCodespaceFunc: func(_ context.Context, userLogin, name string) error {
if userLogin != user.Login {
return fmt.Errorf("unexpected user %q", userLogin)
}
if tt.deleteErr != nil {
return tt.deleteErr
}
return nil
},
}
if tt.opts.codespaceName == "" {
apiMock.ListCodespacesFunc = func(_ context.Context, userLogin string) ([]*api.Codespace, error) {
if userLogin != user.Login {
return nil, fmt.Errorf("unexpected user %q", userLogin)
}
return tt.codespaces, nil
}
} else {
apiMock.GetCodespaceTokenFunc = func(_ context.Context, userLogin, name string) (string, error) {
if userLogin != user.Login {
return "", fmt.Errorf("unexpected user %q", userLogin)
}
return "CS_TOKEN", nil
}
apiMock.GetCodespaceFunc = func(_ context.Context, token, userLogin, name string) (*api.Codespace, error) {
if userLogin != user.Login {
return nil, fmt.Errorf("unexpected user %q", userLogin)
}
if token != "CS_TOKEN" {
return nil, fmt.Errorf("unexpected token %q", token)
}
return tt.codespaces[0], nil
}
}
opts := tt.opts
opts.now = func() time.Time { return now }
opts.prompter = &prompterMock{
ConfirmFunc: func(msg string) (bool, error) {
res, found := tt.confirms[msg]
if !found {
return false, fmt.Errorf("unexpected prompt %q", msg)
}
return res, nil
},
}
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
app := &App{
apiClient: apiMock,
logger: output.NewLogger(stdout, stderr, false),
}
err := app.Delete(context.Background(), opts)
if (err != nil) != tt.wantErr {
t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr)
}
if n := len(apiMock.GetUserCalls()); n != 1 {
t.Errorf("GetUser invoked %d times, expected %d", n, 1)
}
var gotDeleted []string
for _, delArgs := range apiMock.DeleteCodespaceCalls() {
gotDeleted = append(gotDeleted, delArgs.Name)
}
sort.Strings(gotDeleted)
if !sliceEquals(gotDeleted, tt.wantDeleted) {
t.Errorf("deleted %q, want %q", gotDeleted, tt.wantDeleted)
}
if out := stdout.String(); out != tt.wantStdout {
t.Errorf("stdout = %q, want %q", out, tt.wantStdout)
}
if out := sortLines(stderr.String()); out != tt.wantStderr {
t.Errorf("stderr = %q, want %q", out, tt.wantStderr)
}
})
}
}
func sliceEquals(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func sortLines(s string) string {
trailing := ""
if strings.HasSuffix(s, "\n") {
s = strings.TrimSuffix(s, "\n")
trailing = "\n"
}
lines := strings.Split(s, "\n")
sort.Strings(lines)
return strings.Join(lines, "\n") + trailing
}

63
cmd/ghcs/list.go Normal file
View file

@ -0,0 +1,63 @@
package ghcs
import (
"context"
"fmt"
"os"
"github.com/cli/cli/v2/cmd/ghcs/output"
"github.com/cli/cli/v2/internal/api"
"github.com/spf13/cobra"
)
func newListCmd(app *App) *cobra.Command {
var asJSON bool
listCmd := &cobra.Command{
Use: "list",
Short: "List your codespaces",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.List(cmd.Context(), asJSON)
},
}
listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
return listCmd
}
func (a *App) List(ctx context.Context, asJSON bool) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return fmt.Errorf("error getting codespaces: %w", err)
}
table := output.NewTable(os.Stdout, asJSON)
table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"})
for _, codespace := range codespaces {
table.Append([]string{
codespace.Name,
codespace.RepositoryNWO,
codespace.Branch + dirtyStar(codespace.Environment.GitStatus),
codespace.Environment.State,
codespace.CreatedAt,
})
}
table.Render()
return nil
}
func dirtyStar(status api.CodespaceEnvironmentGitStatus) string {
if status.HasUncommitedChanges || status.HasUnpushedChanges {
return "*"
}
return ""
}

112
cmd/ghcs/logs.go Normal file
View file

@ -0,0 +1,112 @@
package ghcs
import (
"context"
"fmt"
"net"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/liveshare"
"github.com/spf13/cobra"
)
func newLogsCmd(app *App) *cobra.Command {
var (
codespace string
follow bool
)
logsCmd := &cobra.Command{
Use: "logs",
Short: "Access codespace logs",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.Logs(cmd.Context(), codespace, follow)
},
}
logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs")
return logsCmd
}
func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) {
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("getting user: %w", err)
}
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
if err := <-authkeys; err != nil {
return err
}
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", ":0") // arbitrary port
if err != nil {
return err
}
defer listen.Close()
localPort := listen.Addr().(*net.TCPAddr).Port
a.logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
}
cmdType := "cat"
if follow {
cmdType = "tail -f"
}
dst := fmt.Sprintf("%s@localhost", sshUser)
cmd, err := codespaces.NewRemoteCommand(
ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
)
if err != nil {
return fmt.Errorf("remote command: %w", err)
}
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
}()
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Run()
}()
select {
case err := <-tunnelClosed:
return fmt.Errorf("connection closed: %w", err)
case err := <-cmdDone:
if err != nil {
return fmt.Errorf("error retrieving logs: %w", err)
}
return nil // success
}
}

54
cmd/ghcs/main/main.go Normal file
View file

@ -0,0 +1,54 @@
package main
import (
"errors"
"fmt"
"io"
"net/http"
"os"
"github.com/cli/cli/v2/cmd/ghcs"
"github.com/cli/cli/v2/cmd/ghcs/output"
"github.com/cli/cli/v2/internal/api"
"github.com/spf13/cobra"
)
func main() {
token := os.Getenv("GITHUB_TOKEN")
rootCmd := ghcs.NewRootCmd(ghcs.NewApp(
output.NewLogger(os.Stdout, os.Stderr, false),
api.New(token, http.DefaultClient),
))
// Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still
// function without the token set.
oldPreRun := rootCmd.PersistentPreRunE
rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
if token == "" {
return errTokenMissing
}
if oldPreRun != nil {
return oldPreRun(cmd, args)
}
return nil
}
if cmd, err := rootCmd.ExecuteC(); err != nil {
explainError(os.Stderr, err, cmd)
os.Exit(1)
}
}
var errTokenMissing = errors.New("GITHUB_TOKEN is missing")
func explainError(w io.Writer, err error, cmd *cobra.Command) {
if errors.Is(err, errTokenMissing) {
fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo")
fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.")
return
}
if errors.Is(err, ghcs.ErrTooManyArgs) {
_ = cmd.Usage()
return
}
}

659
cmd/ghcs/mock_api.go Normal file
View file

@ -0,0 +1,659 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package ghcs
import (
"context"
"sync"
"github.com/cli/cli/v2/internal/api"
)
// apiClientMock is a mock implementation of apiClient.
//
// func TestSomethingThatUsesapiClient(t *testing.T) {
//
// // make and configure a mocked apiClient
// mockedapiClient := &apiClientMock{
// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) {
// panic("mock out the AuthorizedKeys method")
// },
// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
// panic("mock out the CreateCodespace method")
// },
// DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error {
// panic("mock out the DeleteCodespace method")
// },
// GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
// panic("mock out the GetCodespace method")
// },
// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
// panic("mock out the GetCodespaceRegionLocation method")
// },
// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
// panic("mock out the GetCodespaceRepositoryContents method")
// },
// GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) {
// panic("mock out the GetCodespaceToken method")
// },
// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
// panic("mock out the GetCodespacesSKUs method")
// },
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
// panic("mock out the GetRepository method")
// },
// GetUserFunc: func(ctx context.Context) (*api.User, error) {
// panic("mock out the GetUser method")
// },
// ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) {
// panic("mock out the ListCodespaces method")
// },
// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error {
// panic("mock out the StartCodespace method")
// },
// }
//
// // use mockedapiClient in code that requires apiClient
// // and then make assertions.
//
// }
type apiClientMock struct {
// AuthorizedKeysFunc mocks the AuthorizedKeys method.
AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error)
// CreateCodespaceFunc mocks the CreateCodespace method.
CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
// DeleteCodespaceFunc mocks the DeleteCodespace method.
DeleteCodespaceFunc func(ctx context.Context, user string, name string) error
// GetCodespaceFunc mocks the GetCodespace method.
GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error)
// GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method.
GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error)
// GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method.
GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
// GetCodespaceTokenFunc mocks the GetCodespaceToken method.
GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error)
// GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method.
GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error)
// GetRepositoryFunc mocks the GetRepository method.
GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error)
// GetUserFunc mocks the GetUser method.
GetUserFunc func(ctx context.Context) (*api.User, error)
// ListCodespacesFunc mocks the ListCodespaces method.
ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error)
// StartCodespaceFunc mocks the StartCodespace method.
StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error
// calls tracks calls to the methods.
calls struct {
// AuthorizedKeys holds details about calls to the AuthorizedKeys method.
AuthorizedKeys []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
}
// CreateCodespace holds details about calls to the CreateCodespace method.
CreateCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Params is the params argument value.
Params *api.CreateCodespaceParams
}
// DeleteCodespace holds details about calls to the DeleteCodespace method.
DeleteCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
// Name is the name argument value.
Name string
}
// GetCodespace holds details about calls to the GetCodespace method.
GetCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Token is the token argument value.
Token string
// User is the user argument value.
User string
// Name is the name argument value.
Name string
}
// GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method.
GetCodespaceRegionLocation []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method.
GetCodespaceRepositoryContents []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Codespace is the codespace argument value.
Codespace *api.Codespace
// Path is the path argument value.
Path string
}
// GetCodespaceToken holds details about calls to the GetCodespaceToken method.
GetCodespaceToken []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
// Name is the name argument value.
Name string
}
// GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method.
GetCodespacesSKUs []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User *api.User
// Repository is the repository argument value.
Repository *api.Repository
// Branch is the branch argument value.
Branch string
// Location is the location argument value.
Location string
}
// GetRepository holds details about calls to the GetRepository method.
GetRepository []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Nwo is the nwo argument value.
Nwo string
}
// GetUser holds details about calls to the GetUser method.
GetUser []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// ListCodespaces holds details about calls to the ListCodespaces method.
ListCodespaces []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
}
// StartCodespace holds details about calls to the StartCodespace method.
StartCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Token is the token argument value.
Token string
// Codespace is the codespace argument value.
Codespace *api.Codespace
}
}
lockAuthorizedKeys sync.RWMutex
lockCreateCodespace sync.RWMutex
lockDeleteCodespace sync.RWMutex
lockGetCodespace sync.RWMutex
lockGetCodespaceRegionLocation sync.RWMutex
lockGetCodespaceRepositoryContents sync.RWMutex
lockGetCodespaceToken sync.RWMutex
lockGetCodespacesSKUs sync.RWMutex
lockGetRepository sync.RWMutex
lockGetUser sync.RWMutex
lockListCodespaces sync.RWMutex
lockStartCodespace sync.RWMutex
}
// AuthorizedKeys calls AuthorizedKeysFunc.
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
if mock.AuthorizedKeysFunc == nil {
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
}
callInfo := struct {
Ctx context.Context
User string
}{
Ctx: ctx,
User: user,
}
mock.lockAuthorizedKeys.Lock()
mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo)
mock.lockAuthorizedKeys.Unlock()
return mock.AuthorizedKeysFunc(ctx, user)
}
// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys.
// Check the length with:
// len(mockedapiClient.AuthorizedKeysCalls())
func (mock *apiClientMock) AuthorizedKeysCalls() []struct {
Ctx context.Context
User string
} {
var calls []struct {
Ctx context.Context
User string
}
mock.lockAuthorizedKeys.RLock()
calls = mock.calls.AuthorizedKeys
mock.lockAuthorizedKeys.RUnlock()
return calls
}
// CreateCodespace calls CreateCodespaceFunc.
func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
if mock.CreateCodespaceFunc == nil {
panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Params *api.CreateCodespaceParams
}{
Ctx: ctx,
Params: params,
}
mock.lockCreateCodespace.Lock()
mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo)
mock.lockCreateCodespace.Unlock()
return mock.CreateCodespaceFunc(ctx, params)
}
// CreateCodespaceCalls gets all the calls that were made to CreateCodespace.
// Check the length with:
// len(mockedapiClient.CreateCodespaceCalls())
func (mock *apiClientMock) CreateCodespaceCalls() []struct {
Ctx context.Context
Params *api.CreateCodespaceParams
} {
var calls []struct {
Ctx context.Context
Params *api.CreateCodespaceParams
}
mock.lockCreateCodespace.RLock()
calls = mock.calls.CreateCodespace
mock.lockCreateCodespace.RUnlock()
return calls
}
// DeleteCodespace calls DeleteCodespaceFunc.
func (mock *apiClientMock) DeleteCodespace(ctx context.Context, user string, name string) error {
if mock.DeleteCodespaceFunc == nil {
panic("apiClientMock.DeleteCodespaceFunc: method is nil but apiClient.DeleteCodespace was just called")
}
callInfo := struct {
Ctx context.Context
User string
Name string
}{
Ctx: ctx,
User: user,
Name: name,
}
mock.lockDeleteCodespace.Lock()
mock.calls.DeleteCodespace = append(mock.calls.DeleteCodespace, callInfo)
mock.lockDeleteCodespace.Unlock()
return mock.DeleteCodespaceFunc(ctx, user, name)
}
// DeleteCodespaceCalls gets all the calls that were made to DeleteCodespace.
// Check the length with:
// len(mockedapiClient.DeleteCodespaceCalls())
func (mock *apiClientMock) DeleteCodespaceCalls() []struct {
Ctx context.Context
User string
Name string
} {
var calls []struct {
Ctx context.Context
User string
Name string
}
mock.lockDeleteCodespace.RLock()
calls = mock.calls.DeleteCodespace
mock.lockDeleteCodespace.RUnlock()
return calls
}
// GetCodespace calls GetCodespaceFunc.
func (mock *apiClientMock) GetCodespace(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
if mock.GetCodespaceFunc == nil {
panic("apiClientMock.GetCodespaceFunc: method is nil but apiClient.GetCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Token string
User string
Name string
}{
Ctx: ctx,
Token: token,
User: user,
Name: name,
}
mock.lockGetCodespace.Lock()
mock.calls.GetCodespace = append(mock.calls.GetCodespace, callInfo)
mock.lockGetCodespace.Unlock()
return mock.GetCodespaceFunc(ctx, token, user, name)
}
// GetCodespaceCalls gets all the calls that were made to GetCodespace.
// Check the length with:
// len(mockedapiClient.GetCodespaceCalls())
func (mock *apiClientMock) GetCodespaceCalls() []struct {
Ctx context.Context
Token string
User string
Name string
} {
var calls []struct {
Ctx context.Context
Token string
User string
Name string
}
mock.lockGetCodespace.RLock()
calls = mock.calls.GetCodespace
mock.lockGetCodespace.RUnlock()
return calls
}
// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc.
func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
if mock.GetCodespaceRegionLocationFunc == nil {
panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockGetCodespaceRegionLocation.Lock()
mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo)
mock.lockGetCodespaceRegionLocation.Unlock()
return mock.GetCodespaceRegionLocationFunc(ctx)
}
// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation.
// Check the length with:
// len(mockedapiClient.GetCodespaceRegionLocationCalls())
func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockGetCodespaceRegionLocation.RLock()
calls = mock.calls.GetCodespaceRegionLocation
mock.lockGetCodespaceRegionLocation.RUnlock()
return calls
}
// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc.
func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
if mock.GetCodespaceRepositoryContentsFunc == nil {
panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called")
}
callInfo := struct {
Ctx context.Context
Codespace *api.Codespace
Path string
}{
Ctx: ctx,
Codespace: codespace,
Path: path,
}
mock.lockGetCodespaceRepositoryContents.Lock()
mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo)
mock.lockGetCodespaceRepositoryContents.Unlock()
return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path)
}
// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents.
// Check the length with:
// len(mockedapiClient.GetCodespaceRepositoryContentsCalls())
func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct {
Ctx context.Context
Codespace *api.Codespace
Path string
} {
var calls []struct {
Ctx context.Context
Codespace *api.Codespace
Path string
}
mock.lockGetCodespaceRepositoryContents.RLock()
calls = mock.calls.GetCodespaceRepositoryContents
mock.lockGetCodespaceRepositoryContents.RUnlock()
return calls
}
// GetCodespaceToken calls GetCodespaceTokenFunc.
func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) {
if mock.GetCodespaceTokenFunc == nil {
panic("apiClientMock.GetCodespaceTokenFunc: method is nil but apiClient.GetCodespaceToken was just called")
}
callInfo := struct {
Ctx context.Context
User string
Name string
}{
Ctx: ctx,
User: user,
Name: name,
}
mock.lockGetCodespaceToken.Lock()
mock.calls.GetCodespaceToken = append(mock.calls.GetCodespaceToken, callInfo)
mock.lockGetCodespaceToken.Unlock()
return mock.GetCodespaceTokenFunc(ctx, user, name)
}
// GetCodespaceTokenCalls gets all the calls that were made to GetCodespaceToken.
// Check the length with:
// len(mockedapiClient.GetCodespaceTokenCalls())
func (mock *apiClientMock) GetCodespaceTokenCalls() []struct {
Ctx context.Context
User string
Name string
} {
var calls []struct {
Ctx context.Context
User string
Name string
}
mock.lockGetCodespaceToken.RLock()
calls = mock.calls.GetCodespaceToken
mock.lockGetCodespaceToken.RUnlock()
return calls
}
// GetCodespacesSKUs calls GetCodespacesSKUsFunc.
func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
if mock.GetCodespacesSKUsFunc == nil {
panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called")
}
callInfo := struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
}{
Ctx: ctx,
User: user,
Repository: repository,
Branch: branch,
Location: location,
}
mock.lockGetCodespacesSKUs.Lock()
mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo)
mock.lockGetCodespacesSKUs.Unlock()
return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location)
}
// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs.
// Check the length with:
// len(mockedapiClient.GetCodespacesSKUsCalls())
func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
} {
var calls []struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
}
mock.lockGetCodespacesSKUs.RLock()
calls = mock.calls.GetCodespacesSKUs
mock.lockGetCodespacesSKUs.RUnlock()
return calls
}
// GetRepository calls GetRepositoryFunc.
func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) {
if mock.GetRepositoryFunc == nil {
panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called")
}
callInfo := struct {
Ctx context.Context
Nwo string
}{
Ctx: ctx,
Nwo: nwo,
}
mock.lockGetRepository.Lock()
mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo)
mock.lockGetRepository.Unlock()
return mock.GetRepositoryFunc(ctx, nwo)
}
// GetRepositoryCalls gets all the calls that were made to GetRepository.
// Check the length with:
// len(mockedapiClient.GetRepositoryCalls())
func (mock *apiClientMock) GetRepositoryCalls() []struct {
Ctx context.Context
Nwo string
} {
var calls []struct {
Ctx context.Context
Nwo string
}
mock.lockGetRepository.RLock()
calls = mock.calls.GetRepository
mock.lockGetRepository.RUnlock()
return calls
}
// GetUser calls GetUserFunc.
func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) {
if mock.GetUserFunc == nil {
panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockGetUser.Lock()
mock.calls.GetUser = append(mock.calls.GetUser, callInfo)
mock.lockGetUser.Unlock()
return mock.GetUserFunc(ctx)
}
// GetUserCalls gets all the calls that were made to GetUser.
// Check the length with:
// len(mockedapiClient.GetUserCalls())
func (mock *apiClientMock) GetUserCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockGetUser.RLock()
calls = mock.calls.GetUser
mock.lockGetUser.RUnlock()
return calls
}
// ListCodespaces calls ListCodespacesFunc.
func (mock *apiClientMock) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) {
if mock.ListCodespacesFunc == nil {
panic("apiClientMock.ListCodespacesFunc: method is nil but apiClient.ListCodespaces was just called")
}
callInfo := struct {
Ctx context.Context
User string
}{
Ctx: ctx,
User: user,
}
mock.lockListCodespaces.Lock()
mock.calls.ListCodespaces = append(mock.calls.ListCodespaces, callInfo)
mock.lockListCodespaces.Unlock()
return mock.ListCodespacesFunc(ctx, user)
}
// ListCodespacesCalls gets all the calls that were made to ListCodespaces.
// Check the length with:
// len(mockedapiClient.ListCodespacesCalls())
func (mock *apiClientMock) ListCodespacesCalls() []struct {
Ctx context.Context
User string
} {
var calls []struct {
Ctx context.Context
User string
}
mock.lockListCodespaces.RLock()
calls = mock.calls.ListCodespaces
mock.lockListCodespaces.RUnlock()
return calls
}
// StartCodespace calls StartCodespaceFunc.
func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error {
if mock.StartCodespaceFunc == nil {
panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Token string
Codespace *api.Codespace
}{
Ctx: ctx,
Token: token,
Codespace: codespace,
}
mock.lockStartCodespace.Lock()
mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo)
mock.lockStartCodespace.Unlock()
return mock.StartCodespaceFunc(ctx, token, codespace)
}
// StartCodespaceCalls gets all the calls that were made to StartCodespace.
// Check the length with:
// len(mockedapiClient.StartCodespaceCalls())
func (mock *apiClientMock) StartCodespaceCalls() []struct {
Ctx context.Context
Token string
Codespace *api.Codespace
} {
var calls []struct {
Ctx context.Context
Token string
Codespace *api.Codespace
}
mock.lockStartCodespace.RLock()
calls = mock.calls.StartCodespace
mock.lockStartCodespace.RUnlock()
return calls
}

69
cmd/ghcs/mock_prompter.go Normal file
View file

@ -0,0 +1,69 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package ghcs
import (
"sync"
)
// prompterMock is a mock implementation of prompter.
//
// func TestSomethingThatUsesprompter(t *testing.T) {
//
// // make and configure a mocked prompter
// mockedprompter := &prompterMock{
// ConfirmFunc: func(message string) (bool, error) {
// panic("mock out the Confirm method")
// },
// }
//
// // use mockedprompter in code that requires prompter
// // and then make assertions.
//
// }
type prompterMock struct {
// ConfirmFunc mocks the Confirm method.
ConfirmFunc func(message string) (bool, error)
// calls tracks calls to the methods.
calls struct {
// Confirm holds details about calls to the Confirm method.
Confirm []struct {
// Message is the message argument value.
Message string
}
}
lockConfirm sync.RWMutex
}
// Confirm calls ConfirmFunc.
func (mock *prompterMock) Confirm(message string) (bool, error) {
if mock.ConfirmFunc == nil {
panic("prompterMock.ConfirmFunc: method is nil but prompter.Confirm was just called")
}
callInfo := struct {
Message string
}{
Message: message,
}
mock.lockConfirm.Lock()
mock.calls.Confirm = append(mock.calls.Confirm, callInfo)
mock.lockConfirm.Unlock()
return mock.ConfirmFunc(message)
}
// ConfirmCalls gets all the calls that were made to Confirm.
// Check the length with:
// len(mockedprompter.ConfirmCalls())
func (mock *prompterMock) ConfirmCalls() []struct {
Message string
} {
var calls []struct {
Message string
}
mock.lockConfirm.RLock()
calls = mock.calls.Confirm
mock.lockConfirm.RUnlock()
return calls
}

View file

@ -0,0 +1,55 @@
package output
import (
"encoding/json"
"io"
"strings"
"unicode"
)
type jsonwriter struct {
w io.Writer
pretty bool
cols []string
data []interface{}
}
func (j *jsonwriter) SetHeader(cols []string) {
j.cols = cols
}
func (j *jsonwriter) Append(values []string) {
row := make(map[string]string)
for i, v := range values {
row[camelize(j.cols[i])] = v
}
j.data = append(j.data, row)
}
func (j *jsonwriter) Render() {
enc := json.NewEncoder(j.w)
if j.pretty {
enc.SetIndent("", " ")
}
_ = enc.Encode(j.data)
}
func camelize(s string) string {
var b strings.Builder
capitalizeNext := false
for i, r := range s {
if r == ' ' {
capitalizeNext = true
continue
}
if capitalizeNext {
b.WriteRune(unicode.ToUpper(r))
capitalizeNext = false
} else if i == 0 {
b.WriteRune(unicode.ToLower(r))
} else {
b.WriteRune(r)
}
}
return b.String()
}

View file

@ -0,0 +1,31 @@
package output
import (
"io"
"os"
"github.com/olekukonko/tablewriter"
"golang.org/x/term"
)
type Table interface {
SetHeader([]string)
Append([]string)
Render()
}
func NewTable(w io.Writer, asJSON bool) Table {
isTTY := isTTY(w)
if asJSON {
return &jsonwriter{w: w, pretty: isTTY}
}
if isTTY {
return tablewriter.NewWriter(w)
}
return &tabwriter{w: w}
}
func isTTY(w io.Writer) bool {
f, ok := w.(*os.File)
return ok && term.IsTerminal(int(f.Fd()))
}

View file

@ -0,0 +1,25 @@
package output
import (
"fmt"
"io"
)
type tabwriter struct {
w io.Writer
}
func (j *tabwriter) SetHeader([]string) {}
func (j *tabwriter) Append(values []string) {
var sep string
for i, v := range values {
if i == 1 {
sep = "\t"
}
fmt.Fprintf(j.w, "%s%s", sep, v)
}
fmt.Fprint(j.w, "\n")
}
func (j *tabwriter) Render() {}

74
cmd/ghcs/output/logger.go Normal file
View file

@ -0,0 +1,74 @@
package output
import (
"fmt"
"io"
"sync"
)
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
// Disable the Logger to prevent it from writing to stdout in a TTY environment.
func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
return &Logger{
out: stdout,
errout: stderr,
enabled: !disabled && isTTY(stdout),
}
}
// Logger writes to the given stdout/stderr writers.
// If not enabled, Print functions will noop but Error functions will continue
// to write to the stderr writer.
type Logger struct {
mu sync.Mutex // guards the writers
out io.Writer
errout io.Writer
enabled bool
}
// Print writes the arguments to the stdout writer.
func (l *Logger) Print(v ...interface{}) (int, error) {
if !l.enabled {
return 0, nil
}
l.mu.Lock()
defer l.mu.Unlock()
return fmt.Fprint(l.out, v...)
}
// Println writes the arguments to the stdout writer with a newline at the end.
func (l *Logger) Println(v ...interface{}) (int, error) {
if !l.enabled {
return 0, nil
}
l.mu.Lock()
defer l.mu.Unlock()
return fmt.Fprintln(l.out, v...)
}
// Printf writes the formatted arguments to the stdout writer.
func (l *Logger) Printf(f string, v ...interface{}) (int, error) {
if !l.enabled {
return 0, nil
}
l.mu.Lock()
defer l.mu.Unlock()
return fmt.Fprintf(l.out, f, v...)
}
// Errorf writes the formatted arguments to the stderr writer.
func (l *Logger) Errorf(f string, v ...interface{}) (int, error) {
l.mu.Lock()
defer l.mu.Unlock()
return fmt.Fprintf(l.errout, f, v...)
}
// Errorln writes the arguments to the stderr writer with a newline at the end.
func (l *Logger) Errorln(v ...interface{}) (int, error) {
l.mu.Lock()
defer l.mu.Unlock()
return fmt.Fprintln(l.errout, v...)
}

330
cmd/ghcs/ports.go Normal file
View file

@ -0,0 +1,330 @@
package ghcs
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/cli/cli/v2/cmd/ghcs/output"
"github.com/cli/cli/v2/internal/api"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/liveshare"
"github.com/muhammadmuzzammil1998/jsonc"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
// according to the specified flags.
func newPortsCmd(app *App) *cobra.Command {
var (
codespace string
asJSON bool
)
portsCmd := &cobra.Command{
Use: "ports",
Short: "List ports in a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.ListPorts(cmd.Context(), codespace, asJSON)
},
}
portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
portsCmd.AddCommand(newPortsPublicCmd(app))
portsCmd.AddCommand(newPortsPrivateCmd(app))
portsCmd.AddCommand(newPortsForwardCmd(app))
return portsCmd
}
// ListPorts lists known ports in a codespace.
func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
// TODO(josebalius): remove special handling of this error here and it other places
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error choosing codespace: %w", err)
}
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
a.logger.Println("Loading ports...")
ports, err := session.GetSharedServers(ctx)
if err != nil {
return fmt.Errorf("error getting ports of shared servers: %w", err)
}
devContainerResult := <-devContainerCh
if devContainerResult.err != nil {
// Warn about failure to read the devcontainer file. Not a ghcs command error.
_, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error())
}
table := output.NewTable(os.Stdout, asJSON)
table.SetHeader([]string{"Label", "Port", "Public", "Browse URL"})
for _, port := range ports {
sourcePort := strconv.Itoa(port.SourcePort)
var portName string
if devContainerResult.devContainer != nil {
if attributes, ok := devContainerResult.devContainer.PortAttributes[sourcePort]; ok {
portName = attributes.Label
}
}
table.Append([]string{
portName,
sourcePort,
strings.ToUpper(strconv.FormatBool(port.IsPublic)),
fmt.Sprintf("https://%s-%s.githubpreview.dev/", codespace.Name, sourcePort),
})
}
table.Render()
return nil
}
type devContainerResult struct {
devContainer *devContainer
err error
}
type devContainer struct {
PortAttributes map[string]portAttribute `json:"portsAttributes"`
}
type portAttribute struct {
Label string `json:"label"`
}
func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult {
ch := make(chan devContainerResult, 1)
go func() {
contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json")
if err != nil {
ch <- devContainerResult{nil, fmt.Errorf("error getting content: %w", err)}
return
}
if contents == nil {
ch <- devContainerResult{nil, nil}
return
}
convertedJSON := normalizeJSON(jsonc.ToJSON(contents))
if !jsonc.Valid(convertedJSON) {
ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")}
return
}
var container devContainer
if err := json.Unmarshal(convertedJSON, &container); err != nil {
ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %w", err)}
return
}
ch <- devContainerResult{&container, nil}
}()
return ch
}
// newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public.
func newPortsPublicCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "public <port>",
Short: "Mark port as public",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
codespace, err := cmd.Flags().GetString("codespace")
if err != nil {
// should only happen if flag is not defined
// or if the flag is not of string type
// since it's a persistent flag that we control it should never happen
return fmt.Errorf("get codespace flag: %w", err)
}
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true)
},
}
}
// newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private.
func newPortsPrivateCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "private <port>",
Short: "Mark port as private",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
codespace, err := cmd.Flags().GetString("codespace")
if err != nil {
// should only happen if flag is not defined
// or if the flag is not of string type
// since it's a persistent flag that we control it should never happen
return fmt.Errorf("get codespace flag: %w", err)
}
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false)
},
}
}
func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
port, err := strconv.Atoi(sourcePort)
if err != nil {
return fmt.Errorf("error reading port number: %w", err)
}
if err := session.UpdateSharedVisibility(ctx, port, public); err != nil {
return fmt.Errorf("error update port to public: %w", err)
}
state := "PUBLIC"
if !public {
state = "PRIVATE"
}
a.logger.Printf("Port %s is now %s.\n", sourcePort, state)
return nil
}
// NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of
// port pairs from the codespace to localhost.
func newPortsForwardCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "forward <remote-port>:<local-port>...",
Short: "Forward ports",
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
codespace, err := cmd.Flags().GetString("codespace")
if err != nil {
// should only happen if flag is not defined
// or if the flag is not of string type
// since it's a persistent flag that we control it should never happen
return fmt.Errorf("get codespace flag: %w", err)
}
return app.ForwardPorts(cmd.Context(), codespace, args)
},
}
}
func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) {
portPairs, err := getPortPairs(ports)
if err != nil {
return fmt.Errorf("get port pairs: %w", err)
}
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
// Run forwarding of all ports concurrently, aborting all of
// them at the first failure, including cancellation of the context.
group, ctx := errgroup.WithContext(ctx)
for _, pair := range portPairs {
pair := pair
group.Go(func() error {
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
if err != nil {
return err
}
defer listen.Close()
a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
return fwd.ForwardToListener(ctx, listen) // error always non-nil
})
}
return group.Wait() // first error
}
type portPair struct {
remote, local int
}
// getPortPairs parses a list of strings of form "%d:%d" into pairs of (remote, local) numbers.
func getPortPairs(ports []string) ([]portPair, error) {
pp := make([]portPair, 0, len(ports))
for _, portString := range ports {
parts := strings.Split(portString, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("port pair: %q is not valid", portString)
}
remote, err := strconv.Atoi(parts[0])
if err != nil {
return pp, fmt.Errorf("convert remote port to int: %w", err)
}
local, err := strconv.Atoi(parts[1])
if err != nil {
return pp, fmt.Errorf("convert local port to int: %w", err)
}
pp = append(pp, portPair{remote, local})
}
return pp, nil
}
func normalizeJSON(j []byte) []byte {
// remove trailing commas
return bytes.ReplaceAll(j, []byte("},}"), []byte("}}"))
}

30
cmd/ghcs/root.go Normal file
View file

@ -0,0 +1,30 @@
package ghcs
import (
"github.com/spf13/cobra"
)
var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number.
func NewRootCmd(app *App) *cobra.Command {
root := &cobra.Command{
Use: "ghcs",
SilenceUsage: true, // don't print usage message after each error (see #80)
SilenceErrors: false, // print errors automatically so that main need not
Long: `Unofficial CLI tool to manage GitHub Codespaces.
Running commands requires the GITHUB_TOKEN environment variable to be set to a
token to access the GitHub API with.`,
Version: version,
}
root.AddCommand(newCodeCmd(app))
root.AddCommand(newCreateCmd(app))
root.AddCommand(newDeleteCmd(app))
root.AddCommand(newListCmd(app))
root.AddCommand(newLogsCmd(app))
root.AddCommand(newPortsCmd(app))
root.AddCommand(newSSHCmd(app))
return root
}

105
cmd/ghcs/ssh.go Normal file
View file

@ -0,0 +1,105 @@
package ghcs
import (
"context"
"fmt"
"net"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/liveshare"
"github.com/spf13/cobra"
)
func newSSHCmd(app *App) *cobra.Command {
var sshProfile, codespaceName string
var sshServerPort int
sshCmd := &cobra.Command{
Use: "ssh [flags] [--] [ssh-flags] [command]",
Short: "SSH into a codespace",
RunE: func(cmd *cobra.Command, args []string) error {
return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort)
},
}
sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use")
sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace")
return sshCmd
}
// SSH opens an ssh session or runs an ssh command in a codespace.
func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
// Ensure all child tasks (e.g. port forwarding) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
if err := <-authkeys; err != nil {
return err
}
a.logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
}
usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
// Ensure local port is listening before client (Shell) connects.
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort))
if err != nil {
return err
}
defer listen.Close()
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
connectDestination := sshProfile
if connectDestination == "" {
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
}
a.logger.Println("Ready...")
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
}()
shellClosed := make(chan error, 1)
go func() {
shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort)
}()
select {
case err := <-tunnelClosed:
return fmt.Errorf("tunnel closed: %w", err)
case err := <-shellClosed:
if err != nil {
return fmt.Errorf("shell closed: %w", err)
}
return nil // success
}
}

13
go.mod
View file

@ -12,9 +12,11 @@ require (
github.com/cli/safeexec v1.0.0
github.com/cpuguy83/go-md2man/v2 v2.0.0
github.com/creack/pty v1.1.13
github.com/fatih/camelcase v1.0.0
github.com/gabriel-vasile/mimetype v1.1.2
github.com/google/go-cmp v0.5.5
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/go-version v1.2.1
github.com/henvic/httpretty v0.0.6
github.com/itchyny/gojq v0.12.4
@ -24,17 +26,24 @@ require (
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5
github.com/muesli/termenv v0.8.1
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38
github.com/olekukonko/tablewriter v0.0.5
github.com/opentracing/opentracing-go v1.1.0
github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b
github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/sourcegraph/jsonrpc2 v0.1.0
github.com/spf13/cobra v1.2.1
github.com/spf13/pflag v1.0.5
github.com/stretchr/objx v0.1.1 // indirect
github.com/stretchr/testify v1.7.0
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1
golang.org/x/term v0.0.0-20210503060354-a79de5458b56
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
)
replace github.com/shurcooL/graphql => github.com/cli/shurcooL-graphql v0.0.0-20200707151639-0f7232a2bf7e
replace golang.org/x/crypto => github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03

28
go.sum
View file

@ -73,6 +73,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
github.com/cli/browser v1.0.0/go.mod h1:IEWkHYbLjkhtjwwWlwTHW2lGxeS5gezEQBMLTwDHf5Q=
github.com/cli/browser v1.1.0 h1:xOZBfkfY9L9vMBgqb1YwRirGu6QFaQ5dP/vXt5ENSOY=
github.com/cli/browser v1.1.0/go.mod h1:HKMQAt9t12kov91Mn7RfZxyJQQgWgyS/3SZswlZ5iTI=
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai03eoWAI+oGHJwr+5OXfxCr8=
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
github.com/cli/oauth v0.8.0 h1:YTFgPXSTvvDUFti3tR4o6q7Oll2SnQ9ztLwCAn4/IOA=
github.com/cli/oauth v0.8.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
@ -103,6 +105,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8=
github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc=
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
@ -182,6 +186,9 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/graph-gophers/graphql-go v0.0.0-20200622220639-c1d9693c95a6/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q=
@ -273,8 +280,11 @@ github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5 h1:T+Fc6qGlSfM+z0J
github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ=
github.com/muesli/termenv v0.8.1 h1:9q230czSP3DHVpkaPDXGp0TOfAwyjyYwXlUCQxQSaBk=
github.com/muesli/termenv v0.8.1/go.mod h1:kzt/D/4a88RoheZmwfqorY3A+tnsSMA9HJC/fQSFKo0=
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38 h1:0FrBxrkJ0hVembTb/e4EU5Ml6vLcOusAqymmYISg5Uo=
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38/go.mod h1:saF2fIVw4banK0H4+/EuqfFLpRnoy5S+ECwTOCcRcSU=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
@ -300,8 +310,12 @@ github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b h1:0/ecDXh/HTHRt
github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b/go.mod h1:hAF0iLZy4td2EX+/8Tw+4nodhlMrwN3HupfaXj3zkGo=
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/sourcegraph/jsonrpc2 v0.1.0 h1:ohJHjZ+PcaLxDUjqk2NC3tIGsVa5bXThe1ZheSXOjuk=
github.com/sourcegraph/jsonrpc2 v0.1.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo=
github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v1.2.1 h1:+KmjbUw1hriSNMF55oPrkZcb27aECyrj8V2ytv7kWDw=
@ -344,16 +358,6 @@ go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -396,7 +400,6 @@ golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@ -497,8 +500,9 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b h1:qh4f65QIVFjq9eBURLEYWqaEXmOyqdUyiBSgaXWccWk=
golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210503060354-a79de5458b56 h1:b8jxX3zqjpqb2LklXPzKSGJhzyxCOZSz8ncv8Nv+y7w=
golang.org/x/term v0.0.0-20210503060354-a79de5458b56/go.mod h1:tfny5GFUkzUvx4ps4ajbZsCe5lw1metzhBm9T3x7oIY=

640
internal/api/api.go Normal file
View file

@ -0,0 +1,640 @@
package api
// For descriptions of service interfaces, see:
// - https://online.visualstudio.com/api/swagger (for visualstudio.com)
// - https://docs.github.com/en/rest/reference/repos (for api.github.com)
// - https://github.com/github/github/blob/master/app/api/codespaces.rb (for vscs_internal)
// TODO(adonovan): replace the last link with a public doc URL when available.
// TODO(adonovan): a possible reorganization would be to split this
// file into three internal packages, one per backend service, and to
// rename api.API to github.Client:
//
// - github.GetUser(github.Client)
// - github.GetRepository(Client)
// - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents
// - github.AuthorizedKeys(Client, user)
// - codespaces.Create(Client, user, repo, sku, branch, location)
// - codespaces.Delete(Client, user, token, name)
// - codespaces.Get(Client, token, owner, name)
// - codespaces.GetMachineTypes(Client, user, repo, branch, location)
// - codespaces.GetToken(Client, login, name)
// - codespaces.List(Client, user)
// - codespaces.Start(Client, token, codespace)
// - visualstudio.GetRegionLocation(http.Client) // no dependency on github
//
// This would make the meaning of each operation clearer.
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"
"github.com/opentracing/opentracing-go"
)
const githubAPI = "https://api.github.com"
type API struct {
token string
client httpClient
githubAPI string
}
type httpClient interface {
Do(req *http.Request) (*http.Response, error)
}
func New(token string, httpClient httpClient) *API {
return &API{
token: token,
client: httpClient,
githubAPI: githubAPI,
}
}
type User struct {
Login string `json:"login"`
}
func (a *API) GetUser(ctx context.Context) (*User, error) {
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/user", nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
a.setHeaders(req)
resp, err := a.do(ctx, req, "/user")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, jsonErrorResponse(b)
}
var response User
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
return &response, nil
}
func jsonErrorResponse(b []byte) error {
var response struct {
Message string `json:"message"`
}
if err := json.Unmarshal(b, &response); err != nil {
return fmt.Errorf("error unmarshaling error response: %w", err)
}
return errors.New(response.Message)
}
type Repository struct {
ID int `json:"id"`
}
func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error) {
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+strings.ToLower(nwo), nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
a.setHeaders(req)
resp, err := a.do(ctx, req, "/repos/*")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, jsonErrorResponse(b)
}
var response Repository
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
return &response, nil
}
type Codespace struct {
Name string `json:"name"`
GUID string `json:"guid"`
CreatedAt string `json:"created_at"`
LastUsedAt string `json:"last_used_at"`
Branch string `json:"branch"`
RepositoryName string `json:"repository_name"`
RepositoryNWO string `json:"repository_nwo"`
OwnerLogin string `json:"owner_login"`
Environment CodespaceEnvironment `json:"environment"`
}
type CodespaceEnvironment struct {
State string `json:"state"`
Connection CodespaceEnvironmentConnection `json:"connection"`
GitStatus CodespaceEnvironmentGitStatus `json:"gitStatus"`
}
type CodespaceEnvironmentGitStatus struct {
Ahead int `json:"ahead"`
Behind int `json:"behind"`
Branch string `json:"branch"`
Commit string `json:"commit"`
HasUnpushedChanges bool `json:"hasUnpushedChanges"`
HasUncommitedChanges bool `json:"hasUncommitedChanges"`
}
const (
CodespaceEnvironmentStateAvailable = "Available"
)
type CodespaceEnvironmentConnection struct {
SessionID string `json:"sessionId"`
SessionToken string `json:"sessionToken"`
RelayEndpoint string `json:"relayEndpoint"`
RelaySAS string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
}
func (a *API) ListCodespaces(ctx context.Context, user string) ([]*Codespace, error) {
req, err := http.NewRequest(
http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", nil,
)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
a.setHeaders(req)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, jsonErrorResponse(b)
}
var response struct {
Codespaces []*Codespace `json:"codespaces"`
}
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
return response.Codespaces, nil
}
type getCodespaceTokenRequest struct {
MintRepositoryToken bool `json:"mint_repository_token"`
}
type getCodespaceTokenResponse struct {
RepositoryToken string `json:"repository_token"`
}
// ErrNotProvisioned is returned by GetCodespacesToken to indicate that the
// creation of a codespace is not yet complete and that the caller should try again.
var ErrNotProvisioned = errors.New("codespace not provisioned")
func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) {
reqBody, err := json.Marshal(getCodespaceTokenRequest{true})
if err != nil {
return "", fmt.Errorf("error preparing request body: %w", err)
}
req, err := http.NewRequest(
http.MethodPost,
a.githubAPI+"/vscs_internal/user/"+ownerLogin+"/codespaces/"+codespaceName+"/token",
bytes.NewBuffer(reqBody),
)
if err != nil {
return "", fmt.Errorf("error creating request: %w", err)
}
a.setHeaders(req)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*/token")
if err != nil {
return "", fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnprocessableEntity {
return "", ErrNotProvisioned
}
return "", jsonErrorResponse(b)
}
var response getCodespaceTokenResponse
if err := json.Unmarshal(b, &response); err != nil {
return "", fmt.Errorf("error unmarshaling response: %w", err)
}
return response.RepositoryToken, nil
}
func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) (*Codespace, error) {
req, err := http.NewRequest(
http.MethodGet,
a.githubAPI+"/vscs_internal/user/"+owner+"/codespaces/"+codespace,
nil,
)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
// TODO: use a.setHeaders()
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, jsonErrorResponse(b)
}
var response Codespace
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
return &response, nil
}
func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codespace) error {
req, err := http.NewRequest(
http.MethodPost,
a.githubAPI+"/vscs_internal/proxy/environments/"+codespace.GUID+"/start",
nil,
)
if err != nil {
return fmt.Errorf("error creating request: %w", err)
}
// TODO: use a.setHeaders()
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start")
if err != nil {
return fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
// Error response may be a numeric code or a JSON {"message": "..."}.
if bytes.HasPrefix(b, []byte("{")) {
return jsonErrorResponse(b) // probably JSON
}
if len(b) > 100 {
b = append(b[:97], "..."...)
}
if strings.TrimSpace(string(b)) == "7" {
// Non-HTTP 200 with error code 7 (EnvironmentNotShutdown) is benign.
// Ignore it.
} else {
return fmt.Errorf("failed to start codespace: %s", b)
}
}
return nil
}
type getCodespaceRegionLocationResponse struct {
Current string `json:"current"`
}
func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
req, err := http.NewRequest(http.MethodGet, "https://online.visualstudio.com/api/v1/locations", nil)
if err != nil {
return "", fmt.Errorf("error creating request: %w", err)
}
resp, err := a.do(ctx, req, req.URL.String())
if err != nil {
return "", fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", jsonErrorResponse(b)
}
var response getCodespaceRegionLocationResponse
if err := json.Unmarshal(b, &response); err != nil {
return "", fmt.Errorf("error unmarshaling response: %w", err)
}
return response.Current, nil
}
type SKU struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
}
func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) {
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
q := req.URL.Query()
q.Add("location", location)
q.Add("ref", branch)
q.Add("repository_id", strconv.Itoa(repository.ID))
req.URL.RawQuery = q.Encode()
a.setHeaders(req)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, jsonErrorResponse(b)
}
var response struct {
SKUs []*SKU `json:"skus"`
}
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
return response.SKUs, nil
}
// CreateCodespaceParams are the required parameters for provisioning a Codespace.
type CreateCodespaceParams struct {
User string
RepositoryID int
Branch, Machine, Location string
}
// CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it
// fails to create.
func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) {
codespace, err := a.startCreate(
ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location,
)
if err != errProvisioningInProgress {
return codespace, err
}
// errProvisioningInProgress indicates that codespace creation did not complete
// within the GitHub API RPC time limit (10s), so it continues asynchronously.
// We must poll the server to discover the outcome.
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name)
if err != nil {
if err == ErrNotProvisioned {
// Do nothing. We expect this to fail until the codespace is provisioned
continue
}
return nil, fmt.Errorf("failed to get codespace token: %w", err)
}
codespace, err = a.GetCodespace(ctx, token, params.User, codespace.Name)
if err != nil {
return nil, fmt.Errorf("failed to get codespace: %w", err)
}
return codespace, nil
}
}
}
type startCreateRequest struct {
RepositoryID int `json:"repository_id"`
Ref string `json:"ref"`
Location string `json:"location"`
SkuName string `json:"sku_name"`
}
var errProvisioningInProgress = errors.New("provisioning in progress")
// startCreate starts the creation of a codespace.
// It may return success or an error, or errProvisioningInProgress indicating that the operation
// did not complete before the GitHub API's time limit for RPCs (10s), in which case the caller
// must poll the server to learn the outcome.
func (a *API) startCreate(ctx context.Context, user string, repository int, sku, branch, location string) (*Codespace, error) {
requestBody, err := json.Marshal(startCreateRequest{repository, branch, location, sku})
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
req, err := http.NewRequest(http.MethodPost, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", bytes.NewBuffer(requestBody))
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
a.setHeaders(req)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
switch {
case resp.StatusCode > http.StatusAccepted:
return nil, jsonErrorResponse(b)
case resp.StatusCode == http.StatusAccepted:
return nil, errProvisioningInProgress // RPC finished before result of creation known
}
var response Codespace
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
return &response, nil
}
func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName string) error {
token, err := a.GetCodespaceToken(ctx, user, codespaceName)
if err != nil {
return fmt.Errorf("error getting codespace token: %w", err)
}
req, err := http.NewRequest(http.MethodDelete, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces/"+codespaceName, nil)
if err != nil {
return fmt.Errorf("error creating request: %w", err)
}
// TODO: use a.setHeaders()
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*")
if err != nil {
return fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode > http.StatusAccepted {
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error reading response body: %w", err)
}
return jsonErrorResponse(b)
}
return nil
}
type getCodespaceRepositoryContentsResponse struct {
Content string `json:"content"`
}
func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Codespace, path string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
q := req.URL.Query()
q.Add("ref", codespace.Branch)
req.URL.RawQuery = q.Encode()
a.setHeaders(req)
resp, err := a.do(ctx, req, "/repos/*/contents/*")
if err != nil {
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, jsonErrorResponse(b)
}
var response getCodespaceRepositoryContentsResponse
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %w", err)
}
decoded, err := base64.StdEncoding.DecodeString(response.Content)
if err != nil {
return nil, fmt.Errorf("error decoding content: %w", err)
}
return decoded, nil
}
// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys
// format) registered by the specified GitHub user.
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
url := fmt.Sprintf("https://github.com/%s.keys", user)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := a.do(ctx, req, "/user.keys")
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned %s", resp.Status)
}
return b, nil
}
func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) {
// TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter.
span, ctx := opentracing.StartSpanFromContext(ctx, spanName)
defer span.Finish()
req = req.WithContext(ctx)
return a.client.Do(req)
}
func (a *API) setHeaders(req *http.Request) {
if a.token != "" {
req.Header.Set("Authorization", "Bearer "+a.token)
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
}

50
internal/api/api_test.go Normal file
View file

@ -0,0 +1,50 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestListCodespaces(t *testing.T) {
codespaces := []*Codespace{
{
Name: "testcodespace",
CreatedAt: "2021-08-09T10:10:24+02:00",
LastUsedAt: "2021-08-09T13:10:24+02:00",
},
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := struct {
Codespaces []*Codespace `json:"codespaces"`
}{
Codespaces: codespaces,
}
data, _ := json.Marshal(response)
fmt.Fprint(w, string(data))
}))
defer svr.Close()
api := API{
githubAPI: svr.URL,
client: &http.Client{},
token: "faketoken",
}
ctx := context.TODO()
codespaces, err := api.ListCodespaces(ctx, "testuser")
if err != nil {
t.Fatal(err)
}
if len(codespaces) != 1 {
t.Fatalf("expected 1 codespace, got %d", len(codespaces))
}
if codespaces[0].Name != "testcodespace" {
t.Fatalf("expected testcodespace, got %s", codespaces[0].Name)
}
}

View file

@ -0,0 +1,77 @@
package codespaces
import (
"context"
"errors"
"fmt"
"time"
"github.com/cli/cli/v2/internal/api"
"github.com/cli/cli/v2/internal/liveshare"
)
type logger interface {
Print(v ...interface{}) (int, error)
Println(v ...interface{}) (int, error)
}
func connectionReady(codespace *api.Codespace) bool {
return codespace.Environment.Connection.SessionID != "" &&
codespace.Environment.Connection.SessionToken != "" &&
codespace.Environment.Connection.RelayEndpoint != "" &&
codespace.Environment.Connection.RelaySAS != "" &&
codespace.Environment.State == api.CodespaceEnvironmentStateAvailable
}
type apiClient interface {
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
GetCodespaceToken(ctx context.Context, user, codespace string) (string, error)
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
}
// ConnectToLiveshare waits for a Codespace to become running,
// and connects to it using a Live Share session.
func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) {
var startedCodespace bool
if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable {
startedCodespace = true
log.Print("Starting your codespace...")
if err := apiClient.StartCodespace(ctx, token, codespace); err != nil {
return nil, fmt.Errorf("error starting codespace: %w", err)
}
}
for retries := 0; !connectionReady(codespace); retries++ {
if retries > 1 {
if retries%2 == 0 {
log.Print(".")
}
time.Sleep(1 * time.Second)
}
if retries == 30 {
return nil, errors.New("timed out while waiting for the codespace to start")
}
var err error
codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name)
if err != nil {
return nil, fmt.Errorf("error getting codespace: %w", err)
}
}
if startedCodespace {
fmt.Print("\n")
}
log.Println("Connecting to your codespace...")
return liveshare.Connect(ctx, liveshare.Options{
SessionID: codespace.Environment.Connection.SessionID,
SessionToken: codespace.Environment.Connection.SessionToken,
RelaySAS: codespace.Environment.Connection.RelaySAS,
RelayEndpoint: codespace.Environment.Connection.RelayEndpoint,
HostPublicKeys: codespace.Environment.Connection.HostPublicKeys,
})
}

View file

@ -0,0 +1,90 @@
package codespaces
import (
"context"
"fmt"
"os"
"os/exec"
"strconv"
"strings"
)
// Shell runs an interactive secure shell over an existing
// port-forwarding session. It runs until the shell is terminated
// (including by cancellation of the context).
func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error {
cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs)
if err != nil {
return fmt.Errorf("failed to create ssh command: %w", err)
}
if usingCustomPort {
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
}
return cmd.Run()
}
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
// command on the remote machine.
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) {
cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs)
return cmd, err
}
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
// an interactive shell) over ssh.
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) {
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
// The ssh command syntax is: ssh [flags] user@host command [args...]
// There is no way to specify the user@host destination as a flag.
// Unfortunately, that means we need to know which user-provided words are
// SSH flags and which are command arguments so that we can place
// them before or after the destination, and that means we need to know all
// the flags and their arities.
cmdArgs, command, err := parseSSHArgs(cmdArgs)
if err != nil {
return nil, nil, err
}
cmdArgs = append(cmdArgs, connArgs...)
cmdArgs = append(cmdArgs, "-C") // Compression
cmdArgs = append(cmdArgs, dst) // user@host
if command != nil {
cmdArgs = append(cmdArgs, command...)
}
cmd := exec.CommandContext(ctx, "ssh", cmdArgs...)
cmd.Stdout = os.Stdout
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr
return cmd, connArgs, nil
}
// parseSSHArgs parses SSH arguments into two distinct slices of flags and command.
// It returns an error if a unary flag is provided without an argument.
func parseSSHArgs(args []string) (cmdArgs, command []string, err error) {
for i := 0; i < len(args); i++ {
arg := args[i]
// if we've started parsing the command, set it to the rest of the args
if !strings.HasPrefix(arg, "-") {
command = args[i:]
break
}
cmdArgs = append(cmdArgs, arg)
if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) {
if i++; i == len(args) {
return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg)
}
cmdArgs = append(cmdArgs, args[i])
}
}
return cmdArgs, command, nil
}

View file

@ -0,0 +1,105 @@
package codespaces
import (
"fmt"
"testing"
)
func TestParseSSHArgs(t *testing.T) {
type testCase struct {
Args []string
ParsedArgs []string
Command []string
Error string
}
testCases := []testCase{
{}, // empty test case
{
Args: []string{"-X", "-Y"},
ParsedArgs: []string{"-X", "-Y"},
Command: nil,
},
{
Args: []string{"-X", "-Y", "-o", "someoption=test"},
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
Command: nil,
},
{
Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"},
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
Command: []string{"somecommand"},
},
{
Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"},
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
Command: []string{"echo", "test"},
},
{
Args: []string{"somecommand"},
ParsedArgs: []string{},
Command: []string{"somecommand"},
},
{
Args: []string{"echo", "test"},
ParsedArgs: []string{},
Command: []string{"echo", "test"},
},
{
Args: []string{"-v", "echo", "hello", "world"},
ParsedArgs: []string{"-v"},
Command: []string{"echo", "hello", "world"},
},
{
Args: []string{"-L", "-l"},
ParsedArgs: []string{"-L", "-l"},
Command: nil,
},
{
Args: []string{"-v", "echo", "-n", "test"},
ParsedArgs: []string{"-v"},
Command: []string{"echo", "-n", "test"},
},
{
Args: []string{"-v", "echo", "-b", "test"},
ParsedArgs: []string{"-v"},
Command: []string{"echo", "-b", "test"},
},
{
Args: []string{"-b"},
ParsedArgs: nil,
Command: nil,
Error: "ssh flag: -b requires an argument",
},
}
for _, tcase := range testCases {
args, command, err := parseSSHArgs(tcase.Args)
if tcase.Error != "" {
if err == nil {
t.Errorf("expected error and got nil: %#v", tcase)
}
if err.Error() != tcase.Error {
t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error)
}
continue
}
if err != nil {
t.Errorf("unexpected error: %v on test case: %#v", err, tcase)
continue
}
argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs)
if argsStr != parsedArgsStr {
t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr)
}
commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command)
if commandStr != parsedCommandStr {
t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr)
}
}
}

View file

@ -0,0 +1,118 @@
package codespaces
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"strings"
"time"
"github.com/cli/cli/v2/internal/api"
"github.com/cli/cli/v2/internal/liveshare"
)
// PostCreateStateStatus is a string value representing the different statuses a state can have.
type PostCreateStateStatus string
func (p PostCreateStateStatus) String() string {
return strings.Title(string(p))
}
const (
PostCreateStateRunning PostCreateStateStatus = "running"
PostCreateStateSuccess PostCreateStateStatus = "succeeded"
PostCreateStateFailed PostCreateStateStatus = "failed"
)
// PostCreateState is a combination of a state and status value that is captured
// during codespace creation.
type PostCreateState struct {
Name string `json:"name"`
Status PostCreateStateStatus `json:"status"`
}
// PollPostCreateStates watches for state changes in a codespace,
// and calls the supplied poller for each batch of state changes.
// It runs until it encounters an error, including cancellation of the context.
func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name)
if err != nil {
return fmt.Errorf("getting codespace token: %w", err)
}
session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("connect to Live Share: %w", err)
}
defer func() {
if closeErr := session.Close(); err == nil {
err = closeErr
}
}()
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", ":0") // arbitrary port
if err != nil {
return err
}
localPort := listen.Addr().(*net.TCPAddr).Port
log.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
}
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
}()
t := time.NewTicker(1 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-tunnelClosed:
return fmt.Errorf("connection failed: %w", err)
case <-t.C:
states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser)
if err != nil {
return fmt.Errorf("get post create output: %w", err)
}
poller(states)
}
}
}
func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) {
cmd, err := NewRemoteCommand(
ctx, tunnelPort, fmt.Sprintf("%s@localhost", user),
"cat /workspaces/.codespaces/shared/postCreateOutput.json",
)
if err != nil {
return nil, fmt.Errorf("remote command: %w", err)
}
stdout := new(bytes.Buffer)
cmd.Stdout = stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("run command: %w", err)
}
var output struct {
Steps []PostCreateState `json:"steps"`
}
if err := json.Unmarshal(stdout.Bytes(), &output); err != nil {
return nil, fmt.Errorf("unmarshal output: %w", err)
}
return output.Steps, nil
}

View file

@ -0,0 +1,150 @@
// Package liveshare is a Go client library for the Visual Studio Live Share
// service, which provides collaborative, distibuted editing and debugging.
// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview.
//
// It provides the ability for a Go program to connect to a Live Share
// workspace (Connect), to expose a TCP port on a remote host
// (UpdateSharedVisibility), to start an SSH server listening on an
// exposed port (StartSSHServer), and to forward connections between
// the remote port and a local listening TCP port (ForwardToListener)
// or a local Go reader/writer (Forward).
package liveshare
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/url"
"strings"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
// An Options specifies Live Share connection parameters.
type Options struct {
SessionID string
SessionToken string // token for SSH session
RelaySAS string
RelayEndpoint string
HostPublicKeys []string
TLSConfig *tls.Config // (optional)
}
// uri returns a websocket URL for the specified options.
func (opts *Options) uri(action string) (string, error) {
if opts.SessionID == "" {
return "", errors.New("SessionID is required")
}
if opts.RelaySAS == "" {
return "", errors.New("RelaySAS is required")
}
if opts.RelayEndpoint == "" {
return "", errors.New("RelayEndpoint is required")
}
sas := url.QueryEscape(opts.RelaySAS)
uri := opts.RelayEndpoint
uri = strings.Replace(uri, "sb:", "wss:", -1)
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
return uri, nil
}
// Connect connects to a Live Share workspace specified by the
// options, and returns a session representing the connection.
// The caller must call the session's Close method to end the session.
func Connect(ctx context.Context, opts Options) (*Session, error) {
uri, err := opts.uri("connect")
if err != nil {
return nil, err
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
sock := newSocket(uri, opts.TLSConfig)
if err := sock.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting websocket: %w", err)
}
if opts.SessionToken == "" {
return nil, errors.New("SessionToken is required")
}
ssh := newSSHSession(opts.SessionToken, opts.HostPublicKeys, sock)
if err := ssh.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
}
rpc := newRPCClient(ssh)
rpc.connect(ctx)
args := joinWorkspaceArgs{
ID: opts.SessionID,
ConnectionMode: "local",
JoiningUserSessionToken: opts.SessionToken,
ClientCapabilities: clientCapabilities{
IsNonInteractive: false,
},
}
var result joinWorkspaceResult
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
}
return &Session{ssh: ssh, rpc: rpc}, nil
}
type clientCapabilities struct {
IsNonInteractive bool `json:"isNonInteractive"`
}
type joinWorkspaceArgs struct {
ID string `json:"id"`
ConnectionMode string `json:"connectionMode"`
JoiningUserSessionToken string `json:"joiningUserSessionToken"`
ClientCapabilities clientCapabilities `json:"clientCapabilities"`
}
type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}
// A channelID is an identifier for an exposed port on a remote
// container that may be used to open an SSH channel to it.
type channelID struct {
name, condition string
}
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
type getStreamArgs struct {
StreamName string `json:"streamName"`
Condition string `json:"condition"`
}
args := getStreamArgs{
StreamName: id.name,
Condition: id.condition,
}
var streamID string
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %w", err)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
defer span.Finish()
_ = ctx // ctx is not currently used
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
}
go ssh.DiscardRequests(reqs)
requestType := fmt.Sprintf("stream-transport-%s", streamID)
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
return nil, fmt.Errorf("error sending channel request: %w", err)
}
return channel, nil
}

View file

@ -0,0 +1,72 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/cli/cli/v2/internal/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestConnect(t *testing.T) {
opts := Options{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
var joinWorkspaceReq joinWorkspaceArgs
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
return nil, fmt.Errorf("error unmarshaling req: %w", err)
}
if joinWorkspaceReq.ID != opts.SessionID {
return nil, errors.New("connection session id does not match")
}
if joinWorkspaceReq.ConnectionMode != "local" {
return nil, errors.New("connection mode is not local")
}
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
return nil, errors.New("connection user token does not match")
}
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
return nil, errors.New("non interactive is not false")
}
return joinWorkspaceResult{1}, nil
}
server, err := livesharetest.NewServer(
livesharetest.WithPassword(opts.SessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
livesharetest.WithRelaySAS(opts.RelaySAS),
)
if err != nil {
t.Errorf("error creating Live Share server: %w", err)
}
defer server.Close()
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
ctx := context.Background()
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
done := make(chan error)
go func() {
_, err := Connect(ctx, opts) // ignore session
done <- err
}()
select {
case err := <-server.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}

View file

@ -0,0 +1,56 @@
package liveshare
import (
"context"
"testing"
)
func TestBadOptions(t *testing.T) {
goodOptions := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "endpoint",
}
opts := goodOptions
opts.SessionID = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.SessionToken = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelaySAS = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelayEndpoint = ""
checkBadOptions(t, opts)
opts = Options{}
checkBadOptions(t, opts)
}
func checkBadOptions(t *testing.T, opts Options) {
if _, err := Connect(context.Background(), opts); err == nil {
t.Errorf("Connect(%+v): no error", opts)
}
}
func TestOptionsURI(t *testing.T) {
opts := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "sb://endpoint/.net/liveshare",
}
uri, err := opts.uri("connect")
if err != nil {
t.Fatal(err)
}
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
t.Errorf("uri is not correct, got: '%v'", uri)
}
}

View file

@ -0,0 +1,162 @@
package liveshare
import (
"context"
"fmt"
"io"
"net"
"github.com/opentracing/opentracing-go"
)
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
session *Session
name string
remotePort int
}
// NewPortForwarder returns a new PortForwarder for the specified
// remote port and Live Share session. The name describes the purpose
// of the remote port or service.
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
return &PortForwarder{
session: session,
name: name,
remotePort: remotePort,
}
}
// ForwardToListener forwards traffic between the container's remote
// port and a local port, which must already be listening for
// connections. (Accepting a listener rather than a port number avoids
// races against other processes opening ports, and against a client
// connecting to the socket prematurely.)
//
// ForwardToListener accepts and handles connections on the local port
// until it encounters the first error, which may include context
// cancellation. Its error result is always non-nil. The caller is
// responsible for closing the listening port.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
}
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}
go func() {
for {
conn, err := listen.Accept()
if err != nil {
sendError(err)
return
}
go func() {
if err := fwd.handleConnection(ctx, id, conn); err != nil {
sendError(err)
}
}()
}
}()
return awaitError(ctx, errc)
}
// Forward forwards traffic between the container's remote port and
// the specified read/write stream. On return, the stream is closed.
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
conn.Close()
return err
}
// Create buffered channel so that send doesn't get stuck after context cancellation.
errc := make(chan error, 1)
go func() {
errc <- fwd.handleConnection(ctx, id, conn)
}()
return awaitError(ctx, errc)
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
}
return id, err
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err() // canceled
}
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
defer span.Finish()
defer safeClose(conn, &err)
channel, err := fwd.session.openStreamingChannel(ctx, id)
if err != nil {
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
}
// Ideally we would call safeClose again, but (*ssh.channel).Close
// appears to have a bug that causes it return io.EOF spuriously
// if its peer closed first; see github.com/golang/go/issues/38115.
defer func() {
closeErr := channel.Close()
if err == nil && closeErr != io.EOF {
err = closeErr
}
}()
// bi-directional copy of data.
errs := make(chan error, 2)
copyConn := func(w io.Writer, r io.Reader) {
_, err := io.Copy(w, r)
errs <- err
}
go copyConn(conn, channel)
go copyConn(channel, conn)
// Wait until context is cancelled or both copies are done.
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
for i := 0; ; {
select {
case <-ctx.Done():
return ctx.Err()
case <-errs:
i++
if i == 2 {
return nil
}
}
}
}
// safeClose reports the error (to *err) from closing the stream only
// if no other error was previously reported.
func safeClose(closer io.Closer, err *error) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}

View file

@ -0,0 +1,95 @@
package liveshare
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
livesharetest "github.com/cli/cli/v2/internal/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewPortForwarder(t *testing.T) {
testServer, session, err := makeMockSession()
if err != nil {
t.Errorf("create mock client: %w", err)
}
defer testServer.Close()
pf := NewPortForwarder(session, "ssh", 80)
if pf == nil {
t.Error("port forwarder is nil")
}
}
func TestPortForwarderStart(t *testing.T) {
streamName, streamCondition := "stream-name", "stream-condition"
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
}
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
return "stream-id", nil
}
stream := bytes.NewBufferString("stream-data")
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", serverSharing),
livesharetest.WithService("streamManager.getStream", getStream),
livesharetest.WithStream("stream-id", stream),
)
if err != nil {
t.Errorf("create mock session: %w", err)
}
defer testServer.Close()
listen, err := net.Listen("tcp", ":8000")
if err != nil {
t.Fatal(err)
}
defer listen.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error)
go func() {
const name, remote = "ssh", 8000
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
}()
go func() {
var conn net.Conn
retries := 0
for conn == nil && retries < 2 {
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
time.Sleep(1 * time.Second)
}
if conn == nil {
done <- errors.New("failed to connect to forwarded port")
}
b := make([]byte, len("stream-data"))
if _, err := conn.Read(b); err != nil && err != io.EOF {
done <- fmt.Errorf("reading stream: %w", err)
}
if string(b) != "stream-data" {
done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
}
if _, err := conn.Write([]byte("new-data")); err != nil {
done <- fmt.Errorf("writing to stream: %w", err)
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}

41
internal/liveshare/rpc.go Normal file
View file

@ -0,0 +1,41 @@
package liveshare
import (
"context"
"fmt"
"io"
"github.com/opentracing/opentracing-go"
"github.com/sourcegraph/jsonrpc2"
)
type rpcClient struct {
*jsonrpc2.Conn
conn io.ReadWriteCloser
}
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
return &rpcClient{conn: conn}
}
func (r *rpcClient) connect(ctx context.Context) {
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{})
}
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, method)
defer span.Finish()
waiter, err := r.Conn.DispatchCall(ctx, method, args)
if err != nil {
return fmt.Errorf("error dispatching %q call: %w", method, err)
}
return waiter.Wait(ctx, result)
}
type nullHandler struct{}
func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
}

View file

@ -0,0 +1,99 @@
package liveshare
import (
"context"
"fmt"
"strconv"
)
// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
rpc *rpcClient
}
// Close should be called by users to clean up RPC and SSH resources whenever the session
// is no longer active.
func (s *Session) Close() error {
// Closing the RPC conn closes the underlying stream (SSH)
// So we only need to close once
if err := s.rpc.Close(); err != nil {
s.ssh.Close() // close SSH and ignore error
return fmt.Errorf("error while closing Live Share session: %w", err)
}
return nil
}
// Port describes a port exposed by the container.
type Port struct {
SourcePort int `json:"sourcePort"`
DestinationPort int `json:"destinationPort"`
SessionName string `json:"sessionName"`
StreamName string `json:"streamName"`
StreamCondition string `json:"streamCondition"`
BrowseURL string `json:"browseUrl"`
IsPublic bool `json:"isPublic"`
IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"`
HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"`
}
// startSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the remote port or service.
// It returns an identifier that can be used to open an SSH channel to the remote port.
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
var response Port
if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil {
return channelID{}, err
}
return channelID{response.StreamName, response.StreamCondition}, nil
}
// GetSharedServers returns a description of each container port
// shared by a prior call to StartSharing by some client.
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
var response []*Port
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
return nil, err
}
return response, nil
}
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
// via the Browse URL
func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
return err
}
return nil
}
// StartsSSHServer starts an SSH server in the container, installing sshd if necessary,
// and returns the port on which it listens and the user name clients should provide.
func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
var response struct {
Result bool `json:"result"`
ServerPort string `json:"serverPort"`
User string `json:"user"`
Message string `json:"message"`
}
if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil {
return 0, "", err
}
if !response.Result {
return 0, "", fmt.Errorf("failed to start server: %s", response.Message)
}
port, err := strconv.Atoi(response.ServerPort)
if err != nil {
return 0, "", fmt.Errorf("failed to parse port: %w", err)
}
return port, response.User, nil
}

View file

@ -0,0 +1,223 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/cli/cli/v2/internal/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
opts = append(
opts,
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
)
testServer, err := livesharetest.NewServer(opts...)
if err != nil {
return nil, nil, fmt.Errorf("error creating server: %w", err)
}
session, err := Connect(context.Background(), Options{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
})
if err != nil {
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
}
return testServer, session, nil
}
func TestServerStartSharing(t *testing.T) {
serverPort, serverProtocol := 2222, "sshd"
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
var args []interface{}
if err := json.Unmarshal(*req.Params, &args); err != nil {
return nil, fmt.Errorf("error unmarshaling request: %w", err)
}
if len(args) < 3 {
return nil, errors.New("not enough arguments to start sharing")
}
if port, ok := args[0].(float64); !ok {
return nil, errors.New("port argument is not an int")
} else if port != float64(serverPort) {
return nil, errors.New("port does not match serverPort")
}
if protocol, ok := args[1].(string); !ok {
return nil, errors.New("protocol argument is not a string")
} else if protocol != serverProtocol {
return nil, errors.New("protocol does not match serverProtocol")
}
if browseURL, ok := args[2].(string); !ok {
return nil, errors.New("browse url is not a string")
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
return nil, errors.New("browseURL does not match expected")
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
if err != nil {
t.Errorf("error creating mock session: %w", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
if err != nil {
done <- fmt.Errorf("error sharing server: %w", err)
}
if streamID.name == "" || streamID.condition == "" {
done <- errors.New("stream name or condition is blank")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestServerGetSharedServers(t *testing.T) {
sharedServer := Port{
SourcePort: 2222,
StreamName: "stream-name",
StreamCondition: "stream-condition",
}
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
return []*Port{&sharedServer}, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
)
if err != nil {
t.Errorf("error creating mock session: %w", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
ports, err := session.GetSharedServers(ctx)
if err != nil {
done <- fmt.Errorf("error getting shared servers: %w", err)
}
if len(ports) < 1 {
done <- errors.New("not enough ports returned")
}
if ports[0].SourcePort != sharedServer.SourcePort {
done <- errors.New("source port does not match")
}
if ports[0].StreamName != sharedServer.StreamName {
done <- errors.New("stream name does not match")
}
if ports[0].StreamCondition != sharedServer.StreamCondition {
done <- errors.New("stream condiion does not match")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestServerUpdateSharedVisibility(t *testing.T) {
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if port, ok := req[0].(float64); ok {
if port != 80.0 {
return nil, errors.New("port param is not expected value")
}
} else {
return nil, errors.New("port param is not a float64")
}
if public, ok := req[1].(bool); ok {
if public != true {
return nil, errors.New("pulic param is not expected value")
}
} else {
return nil, errors.New("public param is not a bool")
}
return nil, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
)
if err != nil {
t.Errorf("creating mock session: %w", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
done <- session.UpdateSharedVisibility(ctx, 80, true)
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestInvalidHostKey(t *testing.T) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
opts := []livesharetest.ServerOption{
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
}
testServer, err := livesharetest.NewServer(opts...)
if err != nil {
t.Errorf("error creating server: %w", err)
}
_, err = Connect(context.Background(), Options{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
HostPublicKeys: []string{},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
})
if err == nil {
t.Error("expected invalid host key error, got: nil")
}
}

View file

@ -0,0 +1,100 @@
package liveshare
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
"github.com/gorilla/websocket"
)
type socket struct {
addr string
tlsConfig *tls.Config
conn *websocket.Conn
reader io.Reader
}
func newSocket(uri string, tlsConfig *tls.Config) *socket {
return &socket{addr: uri, tlsConfig: tlsConfig}
}
func (s *socket) connect(ctx context.Context) error {
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
TLSClientConfig: s.tlsConfig,
}
ws, _, err := dialer.Dial(s.addr, nil)
if err != nil {
return err
}
s.conn = ws
return nil
}
func (s *socket) Read(b []byte) (int, error) {
if s.reader == nil {
_, reader, err := s.conn.NextReader()
if err != nil {
return 0, err
}
s.reader = reader
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socket) Write(b []byte) (int, error) {
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, err
}
bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()
return bytesWritten, err
}
func (s *socket) Close() error {
return s.conn.Close()
}
func (s *socket) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}
func (s *socket) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *socket) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
return s.SetWriteDeadline(t)
}
func (s *socket) SetReadDeadline(t time.Time) error {
return s.conn.SetReadDeadline(t)
}
func (s *socket) SetWriteDeadline(t time.Time) error {
return s.conn.SetWriteDeadline(t)
}

79
internal/liveshare/ssh.go Normal file
View file

@ -0,0 +1,79 @@
package liveshare
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"time"
"golang.org/x/crypto/ssh"
)
type sshSession struct {
*ssh.Session
token string
hostPublicKeys []string
socket net.Conn
conn ssh.Conn
reader io.Reader
writer io.Writer
}
func newSSHSession(token string, hostPublicKeys []string, socket net.Conn) *sshSession {
return &sshSession{token: token, hostPublicKeys: hostPublicKeys, socket: socket}
}
func (s *sshSession) connect(ctx context.Context) error {
clientConfig := ssh.ClientConfig{
User: "",
Auth: []ssh.AuthMethod{
ssh.Password(s.token),
},
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
HostKeyCallback: func(hostname string, addr net.Addr, key ssh.PublicKey) error {
encodedKey := base64.StdEncoding.EncodeToString(key.Marshal())
for _, hpk := range s.hostPublicKeys {
if encodedKey == hpk {
return nil // we found a match for expected public key, safely return
}
}
return errors.New("invalid host public key")
},
Timeout: 10 * time.Second,
}
sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig)
if err != nil {
return fmt.Errorf("error creating ssh client connection: %w", err)
}
s.conn = sshClientConn
sshClient := ssh.NewClient(sshClientConn, chans, reqs)
s.Session, err = sshClient.NewSession()
if err != nil {
return fmt.Errorf("error creating ssh client session: %w", err)
}
s.reader, err = s.Session.StdoutPipe()
if err != nil {
return fmt.Errorf("error creating ssh session reader: %w", err)
}
s.writer, err = s.Session.StdinPipe()
if err != nil {
return fmt.Errorf("error creating ssh session writer: %w", err)
}
return nil
}
func (s *sshSession) Read(p []byte) (n int, err error) {
return s.reader.Read(p)
}
func (s *sshSession) Write(p []byte) (n int, err error) {
return s.writer.Write(p)
}

View file

@ -0,0 +1,334 @@
package livesharetest
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"github.com/gorilla/websocket"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV
rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY
lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB
AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0
5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8
IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC
46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n
IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC
t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz
J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA
hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV
933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP
FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw==
-----END RSA PRIVATE KEY-----`
const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==`
// Server represents a LiveShare relay host server.
type Server struct {
password string
services map[string]RPCHandleFunc
relaySAS string
streams map[string]io.ReadWriter
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
errCh chan error
}
// NewServer creates a new Server. ServerOptions can be passed to configure
// the SSH password, backing service, secrets and more.
func NewServer(opts ...ServerOption) (*Server, error) {
server := new(Server)
for _, o := range opts {
if err := o(server); err != nil {
return nil, err
}
}
server.sshConfig = &ssh.ServerConfig{
PasswordCallback: sshPasswordCallback(server.password),
}
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
if err != nil {
return nil, fmt.Errorf("error parsing key: %w", err)
}
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error, 1)
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
return server, nil
}
// ServerOption is used to configure the Server.
type ServerOption func(*Server) error
// WithPassword configures the Server password for SSH.
func WithPassword(password string) ServerOption {
return func(s *Server) error {
s.password = password
return nil
}
}
// WithService accepts a mock RPC service for the Server to invoke.
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
return func(s *Server) error {
if s.services == nil {
s.services = make(map[string]RPCHandleFunc)
}
s.services[serviceName] = handler
return nil
}
}
// WithRelaySAS configures the relay SAS configuration key.
func WithRelaySAS(sas string) ServerOption {
return func(s *Server) error {
s.relaySAS = sas
return nil
}
}
// WithStream allows you to specify a mock data stream for the server.
func WithStream(name string, stream io.ReadWriter) ServerOption {
return func(s *Server) error {
if s.streams == nil {
s.streams = make(map[string]io.ReadWriter)
}
s.streams[name] = stream
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
return nil, nil
}
return nil, errors.New("password rejected")
}
}
// Close closes the underlying httptest Server.
func (s *Server) Close() {
s.httptestServer.Close()
}
// URL returns the httptest Server url.
func (s *Server) URL() string {
return s.httptestServer.URL
}
func (s *Server) Err() <-chan error {
return s.errCh
}
var upgrader = websocket.Upgrader{}
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if server.relaySAS != "" {
// validate the sas key
sasParam := req.URL.Query().Get("sb-hc-token")
if sasParam != server.relaySAS {
sendError(server.errCh, errors.New("error validating sas"))
return
}
}
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err))
return
}
defer func() {
if err := c.Close(); err != nil {
sendError(server.errCh, err)
}
}()
socketConn := newSocketConn(c)
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err))
return
}
go ssh.DiscardRequests(reqs)
if err := handleChannels(ctx, server, chans); err != nil {
sendError(server.errCh, err)
}
}
}
// sendError does a non-blocking send of the error to the err channel.
func sendError(errc chan<- error, err error) {
select {
case errc <- err:
default:
// channel is blocked with a previous error, so we ignore
// this current error
}
}
// awaitError waits for the context to finish and returns its error (if any).
// It also waits for an err to come through the err channel.
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errc:
return err
}
}
// handleChannels services the sshChannels channel. For each SSH channel received
// it creates a go routine to service the channel's requests. It returns on the first
// error encountered.
func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error {
errc := make(chan error, 1)
go func() {
for sshCh := range sshChannels {
ch, reqs, err := sshCh.Accept()
if err != nil {
sendError(errc, fmt.Errorf("failed to accept channel: %w", err))
return
}
go func() {
if err := handleRequests(ctx, server, ch, reqs); err != nil {
sendError(errc, fmt.Errorf("failed to handle requests: %w", err))
}
}()
handleChannel(server, ch)
}
}()
return awaitError(ctx, errc)
}
// handleRequests services the SSH channel requests channel. It replies to requests and
// when stream transport requests are encountered, creates a go routine to create a
// bi-directional data stream between the channel and server stream. It returns on the first error
// encountered.
func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error {
errc := make(chan error, 1)
go func() {
for req := range reqs {
if req.WantReply {
if err := req.Reply(true, nil); err != nil {
sendError(errc, fmt.Errorf("error replying to channel request: %w", err))
return
}
}
if strings.HasPrefix(req.Type, "stream-transport") {
go func() {
if err := forwardStream(ctx, server, req.Type, channel); err != nil {
sendError(errc, fmt.Errorf("failed to forward stream: %w", err))
}
}()
}
}
}()
return awaitError(ctx, errc)
}
// concurrentStream is a concurrency safe io.ReadWriter.
type concurrentStream struct {
sync.RWMutex
stream io.ReadWriter
}
func newConcurrentStream(rw io.ReadWriter) *concurrentStream {
return &concurrentStream{stream: rw}
}
func (cs *concurrentStream) Read(b []byte) (int, error) {
cs.RLock()
defer cs.RUnlock()
return cs.stream.Read(b)
}
func (cs *concurrentStream) Write(b []byte) (int, error) {
cs.Lock()
defer cs.Unlock()
return cs.stream.Write(b)
}
// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
// runs until an error is encountered.
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
return fmt.Errorf("stream '%s' not found", simpleStreamName)
}
defer func() {
if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
err = closeErr
}
}()
errc := make(chan error, 2)
copy := func(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
errc <- err
}
}
csStream := newConcurrentStream(stream)
go copy(csStream, channel)
go copy(channel, csStream)
return awaitError(ctx, errc)
}
func handleChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
}
type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
type rpcHandler struct {
server *Server
}
func newRPCHandler(server *Server) *rpcHandler {
return &rpcHandler{server}
}
// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
// RPC service method and if found, it invokes the handler and replies to the request.
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
handler, found := r.server.services[req.Method]
if !found {
sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method))
return
}
result, err := handler(req)
if err != nil {
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
return
}
if err := conn.Reply(ctx, req.ID, result); err != nil {
sendError(r.server.errCh, fmt.Errorf("error replying: %w", err))
}
}

View file

@ -0,0 +1,77 @@
package livesharetest
import (
"fmt"
"io"
"sync"
"time"
"github.com/gorilla/websocket"
)
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %w", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %w", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %w", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %w", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}

View file

@ -2,8 +2,12 @@ package root
import (
"net/http"
"sync"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/cmd/ghcs"
"github.com/cli/cli/v2/cmd/ghcs/output"
ghcsApi "github.com/cli/cli/v2/internal/api"
actionsCmd "github.com/cli/cli/v2/pkg/cmd/actions"
aliasCmd "github.com/cli/cli/v2/pkg/cmd/alias"
apiCmd "github.com/cli/cli/v2/pkg/cmd/api"
@ -78,6 +82,7 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) *cobra.Command {
cmd.AddCommand(extensionCmd.NewCmdExtension(f))
cmd.AddCommand(secretCmd.NewCmdSecret(f))
cmd.AddCommand(sshKeyCmd.NewCmdSSHKey(f))
cmd.AddCommand(newCodespaceCmd(f))
// the `api` command should not inherit any extra HTTP headers
bareHTTPCmdFactory := *f
@ -121,3 +126,39 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er
return factory.NewHTTPClient(f.IOStreams, cfg, version, false)
}
}
func newCodespaceCmd(f *cmdutil.Factory) *cobra.Command {
cmd := ghcs.NewRootCmd(ghcs.NewApp(
output.NewLogger(f.IOStreams.Out, f.IOStreams.ErrOut, !f.IOStreams.IsStdoutTTY()),
ghcsApi.New("", &lazyLoadedHTTPClient{factory: f}),
))
cmd.Use = "codespace"
cmd.Aliases = []string{"cs"}
cmd.Hidden = true
return cmd
}
type lazyLoadedHTTPClient struct {
factory *cmdutil.Factory
httpClientMu sync.RWMutex // guards httpClient
httpClient *http.Client
}
func (l *lazyLoadedHTTPClient) Do(req *http.Request) (*http.Response, error) {
l.httpClientMu.RLock()
httpClient := l.httpClient
l.httpClientMu.RUnlock()
if httpClient == nil {
l.httpClientMu.Lock()
defer l.httpClientMu.Unlock()
var err error
l.httpClient, err = l.factory.HttpClient()
if err != nil {
return nil, err
}
}
return l.httpClient.Do(req)
}