commit
877ad22da6
38 changed files with 5227 additions and 14 deletions
65
cmd/ghcs/code.go
Normal file
65
cmd/ghcs/code.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newCodeCmd(app *App) *cobra.Command {
|
||||
var (
|
||||
codespace string
|
||||
useInsiders bool
|
||||
)
|
||||
|
||||
codeCmd := &cobra.Command{
|
||||
Use: "code",
|
||||
Short: "Open a codespace in VS Code",
|
||||
Args: noArgsConstraint,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.VSCode(cmd.Context(), codespace, useInsiders)
|
||||
},
|
||||
}
|
||||
|
||||
codeCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
|
||||
codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of VS Code")
|
||||
|
||||
return codeCmd
|
||||
}
|
||||
|
||||
// VSCode opens a codespace in the local VS VSCode application.
|
||||
func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error {
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
if codespaceName == "" {
|
||||
codespace, err := chooseCodespace(ctx, a.apiClient, user)
|
||||
if err != nil {
|
||||
if err == errNoCodespaces {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("error choosing codespace: %w", err)
|
||||
}
|
||||
codespaceName = codespace.Name
|
||||
}
|
||||
|
||||
url := vscodeProtocolURL(codespaceName, useInsiders)
|
||||
if err := open.Run(url); err != nil {
|
||||
return fmt.Errorf("error opening vscode URL %s: %s. (Is VS Code installed?)", url, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vscodeProtocolURL(codespaceName string, useInsiders bool) string {
|
||||
application := "vscode"
|
||||
if useInsiders {
|
||||
application = "vscode-insiders"
|
||||
}
|
||||
return fmt.Sprintf("%s://github.codespaces/connect?name=%s", application, url.QueryEscape(codespaceName))
|
||||
}
|
||||
185
cmd/ghcs/common.go
Normal file
185
cmd/ghcs/common.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
package ghcs
|
||||
|
||||
// This file defines functions common to the entire ghcs command set.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/AlecAivazis/survey/v2/terminal"
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
apiClient apiClient
|
||||
logger *output.Logger
|
||||
}
|
||||
|
||||
func NewApp(logger *output.Logger, apiClient apiClient) *App {
|
||||
return &App{
|
||||
apiClient: apiClient,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
|
||||
type apiClient interface {
|
||||
GetUser(ctx context.Context) (*api.User, error)
|
||||
GetCodespaceToken(ctx context.Context, user, name string) (string, error)
|
||||
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
|
||||
ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error)
|
||||
DeleteCodespace(ctx context.Context, user, name string) error
|
||||
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
|
||||
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
|
||||
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
|
||||
AuthorizedKeys(ctx context.Context, user string) ([]byte, error)
|
||||
GetCodespaceRegionLocation(ctx context.Context) (string, error)
|
||||
GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error)
|
||||
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
|
||||
}
|
||||
|
||||
var errNoCodespaces = errors.New("you have no codespaces")
|
||||
|
||||
func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) {
|
||||
codespaces, err := apiClient.ListCodespaces(ctx, user.Login)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting codespaces: %w", err)
|
||||
}
|
||||
return chooseCodespaceFromList(ctx, codespaces)
|
||||
}
|
||||
|
||||
func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) (*api.Codespace, error) {
|
||||
if len(codespaces) == 0 {
|
||||
return nil, errNoCodespaces
|
||||
}
|
||||
|
||||
sort.Slice(codespaces, func(i, j int) bool {
|
||||
return codespaces[i].CreatedAt > codespaces[j].CreatedAt
|
||||
})
|
||||
|
||||
codespacesByName := make(map[string]*api.Codespace)
|
||||
codespacesNames := make([]string, 0, len(codespaces))
|
||||
for _, codespace := range codespaces {
|
||||
codespacesByName[codespace.Name] = codespace
|
||||
codespacesNames = append(codespacesNames, codespace.Name)
|
||||
}
|
||||
|
||||
sshSurvey := []*survey.Question{
|
||||
{
|
||||
Name: "codespace",
|
||||
Prompt: &survey.Select{
|
||||
Message: "Choose codespace:",
|
||||
Options: codespacesNames,
|
||||
Default: codespacesNames[0],
|
||||
},
|
||||
Validate: survey.Required,
|
||||
},
|
||||
}
|
||||
|
||||
var answers struct {
|
||||
Codespace string
|
||||
}
|
||||
if err := ask(sshSurvey, &answers); err != nil {
|
||||
return nil, fmt.Errorf("error getting answers: %w", err)
|
||||
}
|
||||
|
||||
codespace := codespacesByName[answers.Codespace]
|
||||
return codespace, nil
|
||||
}
|
||||
|
||||
// getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty.
|
||||
// It then fetches the codespace token and the codespace record.
|
||||
func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) {
|
||||
if codespaceName == "" {
|
||||
codespace, err = chooseCodespace(ctx, apiClient, user)
|
||||
if err != nil {
|
||||
if err == errNoCodespaces {
|
||||
return nil, "", err
|
||||
}
|
||||
return nil, "", fmt.Errorf("choosing codespace: %w", err)
|
||||
}
|
||||
codespaceName = codespace.Name
|
||||
|
||||
token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("getting codespace token: %w", err)
|
||||
}
|
||||
} else {
|
||||
token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("getting codespace token for given codespace: %w", err)
|
||||
}
|
||||
|
||||
codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("getting full codespace details: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return codespace, token, nil
|
||||
}
|
||||
|
||||
func safeClose(closer io.Closer, err *error) {
|
||||
if closeErr := closer.Close(); *err == nil {
|
||||
*err = closeErr
|
||||
}
|
||||
}
|
||||
|
||||
// hasTTY indicates whether the process connected to a terminal.
|
||||
// It is not portable to assume stdin/stdout are fds 0 and 1.
|
||||
var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||
|
||||
// ask asks survey questions on the terminal, using standard options.
|
||||
// It fails unless hasTTY, but ideally callers should avoid calling it in that case.
|
||||
func ask(qs []*survey.Question, response interface{}) error {
|
||||
if !hasTTY {
|
||||
return fmt.Errorf("no terminal")
|
||||
}
|
||||
err := survey.Ask(qs, response, survey.WithShowCursor(true))
|
||||
// The survey package temporarily clears the terminal's ISIG mode bit
|
||||
// (see tcsetattr(3)) so the QUIT button (Ctrl-C) is reported as
|
||||
// ASCII \x03 (ETX) instead of delivering SIGINT to the application.
|
||||
// So we have to serve ourselves the SIGINT.
|
||||
//
|
||||
// https://github.com/AlecAivazis/survey/#why-isnt-ctrl-c-working
|
||||
if err == terminal.InterruptErr {
|
||||
self, _ := os.FindProcess(os.Getpid())
|
||||
_ = self.Signal(os.Interrupt) // assumes POSIX
|
||||
|
||||
// Suspend the goroutine, to avoid a race between
|
||||
// return from main and async delivery of INT signal.
|
||||
select {}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// checkAuthorizedKeys reports an error if the user has not registered any SSH keys;
|
||||
// see https://github.com/cli/cli/v2/issues/166#issuecomment-921769703.
|
||||
// The check is not required for security but it improves the error message.
|
||||
func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error {
|
||||
keys, err := client.AuthorizedKeys(ctx, user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user)
|
||||
}
|
||||
return nil // success
|
||||
}
|
||||
|
||||
var ErrTooManyArgs = errors.New("the command accepts no arguments")
|
||||
|
||||
func noArgsConstraint(cmd *cobra.Command, args []string) error {
|
||||
if len(args) > 0 {
|
||||
return ErrTooManyArgs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
297
cmd/ghcs/create.go
Normal file
297
cmd/ghcs/create.go
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type createOptions struct {
|
||||
repo string
|
||||
branch string
|
||||
machine string
|
||||
showStatus bool
|
||||
}
|
||||
|
||||
func newCreateCmd(app *App) *cobra.Command {
|
||||
opts := createOptions{}
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a codespace",
|
||||
Args: noArgsConstraint,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.Create(cmd.Context(), opts)
|
||||
},
|
||||
}
|
||||
|
||||
createCmd.Flags().StringVarP(&opts.repo, "repo", "r", "", "repository name with owner: user/repo")
|
||||
createCmd.Flags().StringVarP(&opts.branch, "branch", "b", "", "repository branch")
|
||||
createCmd.Flags().StringVarP(&opts.machine, "machine", "m", "", "hardware specifications for the VM")
|
||||
createCmd.Flags().BoolVarP(&opts.showStatus, "status", "s", false, "show status of post-create command and dotfiles")
|
||||
|
||||
return createCmd
|
||||
}
|
||||
|
||||
// Create creates a new Codespace
|
||||
func (a *App) Create(ctx context.Context, opts createOptions) error {
|
||||
locationCh := getLocation(ctx, a.apiClient)
|
||||
userCh := getUser(ctx, a.apiClient)
|
||||
|
||||
repo, err := getRepoName(opts.repo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting repository name: %w", err)
|
||||
}
|
||||
branch, err := getBranchName(opts.branch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting branch name: %w", err)
|
||||
}
|
||||
|
||||
repository, err := a.apiClient.GetRepository(ctx, repo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting repository: %w", err)
|
||||
}
|
||||
|
||||
locationResult := <-locationCh
|
||||
if locationResult.Err != nil {
|
||||
return fmt.Errorf("error getting codespace region location: %w", locationResult.Err)
|
||||
}
|
||||
|
||||
userResult := <-userCh
|
||||
if userResult.Err != nil {
|
||||
return fmt.Errorf("error getting codespace user: %w", userResult.Err)
|
||||
}
|
||||
|
||||
machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting machine type: %w", err)
|
||||
}
|
||||
if machine == "" {
|
||||
return errors.New("there are no available machine types for this repository")
|
||||
}
|
||||
|
||||
a.logger.Print("Creating your codespace...")
|
||||
codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{
|
||||
User: userResult.User.Login,
|
||||
RepositoryID: repository.ID,
|
||||
Branch: branch,
|
||||
Machine: machine,
|
||||
Location: locationResult.Location,
|
||||
})
|
||||
a.logger.Print("\n")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating codespace: %w", err)
|
||||
}
|
||||
|
||||
if opts.showStatus {
|
||||
if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil {
|
||||
return fmt.Errorf("show status: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
a.logger.Printf("Codespace created: ")
|
||||
|
||||
fmt.Fprintln(os.Stdout, codespace.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// showStatus polls the codespace for a list of post create states and their status. It will keep polling
|
||||
// until all states have finished. Once all states have finished, we poll once more to check if any new
|
||||
// states have been introduced and stop polling otherwise.
|
||||
func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error {
|
||||
var lastState codespaces.PostCreateState
|
||||
var breakNextState bool
|
||||
|
||||
finishedStates := make(map[string]bool)
|
||||
ctx, stopPolling := context.WithCancel(ctx)
|
||||
defer stopPolling()
|
||||
|
||||
poller := func(states []codespaces.PostCreateState) {
|
||||
var inProgress bool
|
||||
for _, state := range states {
|
||||
if _, found := finishedStates[state.Name]; found {
|
||||
continue // skip this state as we've processed it already
|
||||
}
|
||||
|
||||
if state.Name != lastState.Name {
|
||||
log.Print(state.Name)
|
||||
|
||||
if state.Status == codespaces.PostCreateStateRunning {
|
||||
inProgress = true
|
||||
lastState = state
|
||||
log.Print("...")
|
||||
break
|
||||
}
|
||||
|
||||
finishedStates[state.Name] = true
|
||||
log.Println("..." + state.Status)
|
||||
} else {
|
||||
if state.Status == codespaces.PostCreateStateRunning {
|
||||
inProgress = true
|
||||
log.Print(".")
|
||||
break
|
||||
}
|
||||
|
||||
finishedStates[state.Name] = true
|
||||
log.Println(state.Status)
|
||||
lastState = codespaces.PostCreateState{} // reset the value
|
||||
}
|
||||
}
|
||||
|
||||
if !inProgress {
|
||||
if breakNextState {
|
||||
stopPolling()
|
||||
return
|
||||
}
|
||||
breakNextState = true
|
||||
}
|
||||
}
|
||||
|
||||
err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) && breakNextState {
|
||||
return nil // we cancelled the context to stop polling, we can ignore the error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to poll state changes from codespace: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type getUserResult struct {
|
||||
User *api.User
|
||||
Err error
|
||||
}
|
||||
|
||||
// getUser fetches the user record associated with the GITHUB_TOKEN
|
||||
func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult {
|
||||
ch := make(chan getUserResult, 1)
|
||||
go func() {
|
||||
user, err := apiClient.GetUser(ctx)
|
||||
ch <- getUserResult{user, err}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
type locationResult struct {
|
||||
Location string
|
||||
Err error
|
||||
}
|
||||
|
||||
// getLocation fetches the closest Codespace datacenter region/location to the user.
|
||||
func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult {
|
||||
ch := make(chan locationResult, 1)
|
||||
go func() {
|
||||
location, err := apiClient.GetCodespaceRegionLocation(ctx)
|
||||
ch <- locationResult{location, err}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// getRepoName prompts the user for the name of the repository, or returns the repository if non-empty.
|
||||
func getRepoName(repo string) (string, error) {
|
||||
if repo != "" {
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
repoSurvey := []*survey.Question{
|
||||
{
|
||||
Name: "repository",
|
||||
Prompt: &survey.Input{Message: "Repository:"},
|
||||
Validate: survey.Required,
|
||||
},
|
||||
}
|
||||
err := ask(repoSurvey, &repo)
|
||||
return repo, err
|
||||
}
|
||||
|
||||
// getBranchName prompts the user for the name of the branch, or returns the branch if non-empty.
|
||||
func getBranchName(branch string) (string, error) {
|
||||
if branch != "" {
|
||||
return branch, nil
|
||||
}
|
||||
|
||||
branchSurvey := []*survey.Question{
|
||||
{
|
||||
Name: "branch",
|
||||
Prompt: &survey.Input{Message: "Branch:"},
|
||||
Validate: survey.Required,
|
||||
},
|
||||
}
|
||||
err := ask(branchSurvey, &branch)
|
||||
return branch, err
|
||||
}
|
||||
|
||||
// getMachineName prompts the user to select the machine type, or validates the machine if non-empty.
|
||||
func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) {
|
||||
skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error requesting machine instance types: %w", err)
|
||||
}
|
||||
|
||||
// if user supplied a machine type, it must be valid
|
||||
// if no machine type was supplied, we don't error if there are no machine types for the current repo
|
||||
if machine != "" {
|
||||
for _, sku := range skus {
|
||||
if machine == sku.Name {
|
||||
return machine, nil
|
||||
}
|
||||
}
|
||||
|
||||
availableSKUs := make([]string, len(skus))
|
||||
for i := 0; i < len(skus); i++ {
|
||||
availableSKUs[i] = skus[i].Name
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSKUs)
|
||||
} else if len(skus) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if len(skus) == 1 {
|
||||
return skus[0].Name, nil // VS Code does not prompt for SKU if there is only one, this makes us consistent with that behavior
|
||||
}
|
||||
|
||||
skuNames := make([]string, 0, len(skus))
|
||||
skuByName := make(map[string]*api.SKU)
|
||||
for _, sku := range skus {
|
||||
nameParts := camelcase.Split(sku.Name)
|
||||
machineName := strings.Title(strings.ToLower(nameParts[0]))
|
||||
skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName)
|
||||
skuNames = append(skuNames, skuName)
|
||||
skuByName[skuName] = sku
|
||||
}
|
||||
|
||||
skuSurvey := []*survey.Question{
|
||||
{
|
||||
Name: "sku",
|
||||
Prompt: &survey.Select{
|
||||
Message: "Choose Machine Type:",
|
||||
Options: skuNames,
|
||||
Default: skuNames[0],
|
||||
},
|
||||
Validate: survey.Required,
|
||||
},
|
||||
}
|
||||
|
||||
var skuAnswers struct{ SKU string }
|
||||
if err := ask(skuSurvey, &skuAnswers); err != nil {
|
||||
return "", fmt.Errorf("error getting SKU: %w", err)
|
||||
}
|
||||
|
||||
sku := skuByName[skuAnswers.SKU]
|
||||
machine = sku.Name
|
||||
|
||||
return machine, nil
|
||||
}
|
||||
180
cmd/ghcs/delete.go
Normal file
180
cmd/ghcs/delete.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type deleteOptions struct {
|
||||
deleteAll bool
|
||||
skipConfirm bool
|
||||
codespaceName string
|
||||
repoFilter string
|
||||
keepDays uint16
|
||||
|
||||
isInteractive bool
|
||||
now func() time.Time
|
||||
prompter prompter
|
||||
}
|
||||
|
||||
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_prompter.go . prompter
|
||||
type prompter interface {
|
||||
Confirm(message string) (bool, error)
|
||||
}
|
||||
|
||||
func newDeleteCmd(app *App) *cobra.Command {
|
||||
opts := deleteOptions{
|
||||
isInteractive: hasTTY,
|
||||
now: time.Now,
|
||||
prompter: &surveyPrompter{},
|
||||
}
|
||||
|
||||
deleteCmd := &cobra.Command{
|
||||
Use: "delete",
|
||||
Short: "Delete a codespace",
|
||||
Args: noArgsConstraint,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if opts.deleteAll && opts.repoFilter != "" {
|
||||
return errors.New("both --all and --repo is not supported")
|
||||
}
|
||||
return app.Delete(cmd.Context(), opts)
|
||||
},
|
||||
}
|
||||
|
||||
deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace")
|
||||
deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces")
|
||||
deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a `repository`")
|
||||
deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes")
|
||||
deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days")
|
||||
|
||||
return deleteCmd
|
||||
}
|
||||
|
||||
func (a *App) Delete(ctx context.Context, opts deleteOptions) error {
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
var codespaces []*api.Codespace
|
||||
nameFilter := opts.codespaceName
|
||||
if nameFilter == "" {
|
||||
codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting codespaces: %w", err)
|
||||
}
|
||||
|
||||
if !opts.deleteAll && opts.repoFilter == "" {
|
||||
c, err := chooseCodespaceFromList(ctx, codespaces)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error choosing codespace: %w", err)
|
||||
}
|
||||
nameFilter = c.Name
|
||||
}
|
||||
} else {
|
||||
// TODO: this token is discarded and then re-requested later in DeleteCodespace
|
||||
token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting codespace token: %w", err)
|
||||
}
|
||||
|
||||
codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching codespace information: %w", err)
|
||||
}
|
||||
|
||||
codespaces = []*api.Codespace{codespace}
|
||||
}
|
||||
|
||||
codespacesToDelete := make([]*api.Codespace, 0, len(codespaces))
|
||||
lastUpdatedCutoffTime := opts.now().AddDate(0, 0, -int(opts.keepDays))
|
||||
for _, c := range codespaces {
|
||||
if nameFilter != "" && c.Name != nameFilter {
|
||||
continue
|
||||
}
|
||||
if opts.repoFilter != "" && !strings.EqualFold(c.RepositoryNWO, opts.repoFilter) {
|
||||
continue
|
||||
}
|
||||
if opts.keepDays > 0 {
|
||||
t, err := time.Parse(time.RFC3339, c.LastUsedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing last_used_at timestamp %q: %w", c.LastUsedAt, err)
|
||||
}
|
||||
if t.After(lastUpdatedCutoffTime) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !opts.skipConfirm {
|
||||
confirmed, err := confirmDeletion(opts.prompter, c, opts.isInteractive)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to confirm: %w", err)
|
||||
}
|
||||
if !confirmed {
|
||||
continue
|
||||
}
|
||||
}
|
||||
codespacesToDelete = append(codespacesToDelete, c)
|
||||
}
|
||||
|
||||
if len(codespacesToDelete) == 0 {
|
||||
return errors.New("no codespaces to delete")
|
||||
}
|
||||
|
||||
g := errgroup.Group{}
|
||||
for _, c := range codespacesToDelete {
|
||||
codespaceName := c.Name
|
||||
g.Go(func() error {
|
||||
if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil {
|
||||
_, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return errors.New("some codespaces failed to delete")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func confirmDeletion(p prompter, codespace *api.Codespace, isInteractive bool) (bool, error) {
|
||||
gs := codespace.Environment.GitStatus
|
||||
hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges
|
||||
if !hasUnsavedChanges {
|
||||
return true, nil
|
||||
}
|
||||
if !isInteractive {
|
||||
return false, fmt.Errorf("codespace %s has unsaved changes (use --force to override)", codespace.Name)
|
||||
}
|
||||
return p.Confirm(fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name))
|
||||
}
|
||||
|
||||
type surveyPrompter struct{}
|
||||
|
||||
func (p *surveyPrompter) Confirm(message string) (bool, error) {
|
||||
var confirmed struct {
|
||||
Confirmed bool
|
||||
}
|
||||
q := []*survey.Question{
|
||||
{
|
||||
Name: "confirmed",
|
||||
Prompt: &survey.Confirm{
|
||||
Message: message,
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := ask(q, &confirmed); err != nil {
|
||||
return false, fmt.Errorf("failed to prompt: %w", err)
|
||||
}
|
||||
|
||||
return confirmed.Confirmed, nil
|
||||
}
|
||||
257
cmd/ghcs/delete_test.go
Normal file
257
cmd/ghcs/delete_test.go
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
user := &api.User{Login: "hubot"}
|
||||
now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z")
|
||||
daysAgo := func(n int) string {
|
||||
return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts deleteOptions
|
||||
codespaces []*api.Codespace
|
||||
confirms map[string]bool
|
||||
deleteErr error
|
||||
wantErr bool
|
||||
wantDeleted []string
|
||||
wantStdout string
|
||||
wantStderr string
|
||||
}{
|
||||
{
|
||||
name: "by name",
|
||||
opts: deleteOptions{
|
||||
codespaceName: "hubot-robawt-abc",
|
||||
},
|
||||
codespaces: []*api.Codespace{
|
||||
{
|
||||
Name: "hubot-robawt-abc",
|
||||
},
|
||||
},
|
||||
wantDeleted: []string{"hubot-robawt-abc"},
|
||||
},
|
||||
{
|
||||
name: "by repo",
|
||||
opts: deleteOptions{
|
||||
repoFilter: "monalisa/spoon-knife",
|
||||
},
|
||||
codespaces: []*api.Codespace{
|
||||
{
|
||||
Name: "monalisa-spoonknife-123",
|
||||
RepositoryNWO: "monalisa/Spoon-Knife",
|
||||
},
|
||||
{
|
||||
Name: "hubot-robawt-abc",
|
||||
RepositoryNWO: "hubot/ROBAWT",
|
||||
},
|
||||
{
|
||||
Name: "monalisa-spoonknife-c4f3",
|
||||
RepositoryNWO: "monalisa/Spoon-Knife",
|
||||
},
|
||||
},
|
||||
wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"},
|
||||
},
|
||||
{
|
||||
name: "unused",
|
||||
opts: deleteOptions{
|
||||
deleteAll: true,
|
||||
keepDays: 3,
|
||||
},
|
||||
codespaces: []*api.Codespace{
|
||||
{
|
||||
Name: "monalisa-spoonknife-123",
|
||||
LastUsedAt: daysAgo(1),
|
||||
},
|
||||
{
|
||||
Name: "hubot-robawt-abc",
|
||||
LastUsedAt: daysAgo(4),
|
||||
},
|
||||
{
|
||||
Name: "monalisa-spoonknife-c4f3",
|
||||
LastUsedAt: daysAgo(10),
|
||||
},
|
||||
},
|
||||
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
|
||||
},
|
||||
{
|
||||
name: "deletion failed",
|
||||
opts: deleteOptions{
|
||||
deleteAll: true,
|
||||
},
|
||||
codespaces: []*api.Codespace{
|
||||
{
|
||||
Name: "monalisa-spoonknife-123",
|
||||
},
|
||||
{
|
||||
Name: "hubot-robawt-abc",
|
||||
},
|
||||
},
|
||||
deleteErr: errors.New("aborted by test"),
|
||||
wantErr: true,
|
||||
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-123"},
|
||||
wantStderr: heredoc.Doc(`
|
||||
error deleting codespace "hubot-robawt-abc": aborted by test
|
||||
error deleting codespace "monalisa-spoonknife-123": aborted by test
|
||||
`),
|
||||
},
|
||||
{
|
||||
name: "with confirm",
|
||||
opts: deleteOptions{
|
||||
isInteractive: true,
|
||||
deleteAll: true,
|
||||
skipConfirm: false,
|
||||
},
|
||||
codespaces: []*api.Codespace{
|
||||
{
|
||||
Name: "monalisa-spoonknife-123",
|
||||
Environment: api.CodespaceEnvironment{
|
||||
GitStatus: api.CodespaceEnvironmentGitStatus{
|
||||
HasUnpushedChanges: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hubot-robawt-abc",
|
||||
Environment: api.CodespaceEnvironment{
|
||||
GitStatus: api.CodespaceEnvironmentGitStatus{
|
||||
HasUncommitedChanges: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "monalisa-spoonknife-c4f3",
|
||||
Environment: api.CodespaceEnvironment{
|
||||
GitStatus: api.CodespaceEnvironmentGitStatus{
|
||||
HasUnpushedChanges: false,
|
||||
HasUncommitedChanges: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
confirms: map[string]bool{
|
||||
"Codespace monalisa-spoonknife-123 has unsaved changes. OK to delete?": false,
|
||||
"Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true,
|
||||
},
|
||||
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
apiMock := &apiClientMock{
|
||||
GetUserFunc: func(_ context.Context) (*api.User, error) {
|
||||
return user, nil
|
||||
},
|
||||
DeleteCodespaceFunc: func(_ context.Context, userLogin, name string) error {
|
||||
if userLogin != user.Login {
|
||||
return fmt.Errorf("unexpected user %q", userLogin)
|
||||
}
|
||||
if tt.deleteErr != nil {
|
||||
return tt.deleteErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if tt.opts.codespaceName == "" {
|
||||
apiMock.ListCodespacesFunc = func(_ context.Context, userLogin string) ([]*api.Codespace, error) {
|
||||
if userLogin != user.Login {
|
||||
return nil, fmt.Errorf("unexpected user %q", userLogin)
|
||||
}
|
||||
return tt.codespaces, nil
|
||||
}
|
||||
} else {
|
||||
apiMock.GetCodespaceTokenFunc = func(_ context.Context, userLogin, name string) (string, error) {
|
||||
if userLogin != user.Login {
|
||||
return "", fmt.Errorf("unexpected user %q", userLogin)
|
||||
}
|
||||
return "CS_TOKEN", nil
|
||||
}
|
||||
apiMock.GetCodespaceFunc = func(_ context.Context, token, userLogin, name string) (*api.Codespace, error) {
|
||||
if userLogin != user.Login {
|
||||
return nil, fmt.Errorf("unexpected user %q", userLogin)
|
||||
}
|
||||
if token != "CS_TOKEN" {
|
||||
return nil, fmt.Errorf("unexpected token %q", token)
|
||||
}
|
||||
return tt.codespaces[0], nil
|
||||
}
|
||||
}
|
||||
opts := tt.opts
|
||||
opts.now = func() time.Time { return now }
|
||||
opts.prompter = &prompterMock{
|
||||
ConfirmFunc: func(msg string) (bool, error) {
|
||||
res, found := tt.confirms[msg]
|
||||
if !found {
|
||||
return false, fmt.Errorf("unexpected prompt %q", msg)
|
||||
}
|
||||
return res, nil
|
||||
},
|
||||
}
|
||||
|
||||
stdout := &bytes.Buffer{}
|
||||
stderr := &bytes.Buffer{}
|
||||
app := &App{
|
||||
apiClient: apiMock,
|
||||
logger: output.NewLogger(stdout, stderr, false),
|
||||
}
|
||||
err := app.Delete(context.Background(), opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if n := len(apiMock.GetUserCalls()); n != 1 {
|
||||
t.Errorf("GetUser invoked %d times, expected %d", n, 1)
|
||||
}
|
||||
var gotDeleted []string
|
||||
for _, delArgs := range apiMock.DeleteCodespaceCalls() {
|
||||
gotDeleted = append(gotDeleted, delArgs.Name)
|
||||
}
|
||||
sort.Strings(gotDeleted)
|
||||
if !sliceEquals(gotDeleted, tt.wantDeleted) {
|
||||
t.Errorf("deleted %q, want %q", gotDeleted, tt.wantDeleted)
|
||||
}
|
||||
if out := stdout.String(); out != tt.wantStdout {
|
||||
t.Errorf("stdout = %q, want %q", out, tt.wantStdout)
|
||||
}
|
||||
if out := sortLines(stderr.String()); out != tt.wantStderr {
|
||||
t.Errorf("stderr = %q, want %q", out, tt.wantStderr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func sliceEquals(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func sortLines(s string) string {
|
||||
trailing := ""
|
||||
if strings.HasSuffix(s, "\n") {
|
||||
s = strings.TrimSuffix(s, "\n")
|
||||
trailing = "\n"
|
||||
}
|
||||
lines := strings.Split(s, "\n")
|
||||
sort.Strings(lines)
|
||||
return strings.Join(lines, "\n") + trailing
|
||||
}
|
||||
63
cmd/ghcs/list.go
Normal file
63
cmd/ghcs/list.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newListCmd(app *App) *cobra.Command {
|
||||
var asJSON bool
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List your codespaces",
|
||||
Args: noArgsConstraint,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.List(cmd.Context(), asJSON)
|
||||
},
|
||||
}
|
||||
|
||||
listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
|
||||
|
||||
return listCmd
|
||||
}
|
||||
|
||||
func (a *App) List(ctx context.Context, asJSON bool) error {
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting codespaces: %w", err)
|
||||
}
|
||||
|
||||
table := output.NewTable(os.Stdout, asJSON)
|
||||
table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"})
|
||||
for _, codespace := range codespaces {
|
||||
table.Append([]string{
|
||||
codespace.Name,
|
||||
codespace.RepositoryNWO,
|
||||
codespace.Branch + dirtyStar(codespace.Environment.GitStatus),
|
||||
codespace.Environment.State,
|
||||
codespace.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
table.Render()
|
||||
return nil
|
||||
}
|
||||
|
||||
func dirtyStar(status api.CodespaceEnvironmentGitStatus) string {
|
||||
if status.HasUncommitedChanges || status.HasUnpushedChanges {
|
||||
return "*"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
112
cmd/ghcs/logs.go
Normal file
112
cmd/ghcs/logs.go
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newLogsCmd(app *App) *cobra.Command {
|
||||
var (
|
||||
codespace string
|
||||
follow bool
|
||||
)
|
||||
|
||||
logsCmd := &cobra.Command{
|
||||
Use: "logs",
|
||||
Short: "Access codespace logs",
|
||||
Args: noArgsConstraint,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.Logs(cmd.Context(), codespace, follow)
|
||||
},
|
||||
}
|
||||
|
||||
logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
|
||||
logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs")
|
||||
|
||||
return logsCmd
|
||||
}
|
||||
|
||||
func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) {
|
||||
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting user: %w", err)
|
||||
}
|
||||
|
||||
authkeys := make(chan error, 1)
|
||||
go func() {
|
||||
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
|
||||
}()
|
||||
|
||||
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get or choose codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to Live Share: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
if err := <-authkeys; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, err := net.Listen("tcp", ":0") // arbitrary port
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
a.logger.Println("Fetching SSH Details...")
|
||||
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ssh server details: %w", err)
|
||||
}
|
||||
|
||||
cmdType := "cat"
|
||||
if follow {
|
||||
cmdType = "tail -f"
|
||||
}
|
||||
|
||||
dst := fmt.Sprintf("%s@localhost", sshUser)
|
||||
cmd, err := codespaces.NewRemoteCommand(
|
||||
ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remote command: %w", err)
|
||||
}
|
||||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
}()
|
||||
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Run()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-tunnelClosed:
|
||||
return fmt.Errorf("connection closed: %w", err)
|
||||
case err := <-cmdDone:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving logs: %w", err)
|
||||
}
|
||||
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
54
cmd/ghcs/main/main.go
Normal file
54
cmd/ghcs/main/main.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/cli/cli/v2/cmd/ghcs"
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
token := os.Getenv("GITHUB_TOKEN")
|
||||
rootCmd := ghcs.NewRootCmd(ghcs.NewApp(
|
||||
output.NewLogger(os.Stdout, os.Stderr, false),
|
||||
api.New(token, http.DefaultClient),
|
||||
))
|
||||
|
||||
// Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still
|
||||
// function without the token set.
|
||||
oldPreRun := rootCmd.PersistentPreRunE
|
||||
rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
|
||||
if token == "" {
|
||||
return errTokenMissing
|
||||
}
|
||||
if oldPreRun != nil {
|
||||
return oldPreRun(cmd, args)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if cmd, err := rootCmd.ExecuteC(); err != nil {
|
||||
explainError(os.Stderr, err, cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var errTokenMissing = errors.New("GITHUB_TOKEN is missing")
|
||||
|
||||
func explainError(w io.Writer, err error, cmd *cobra.Command) {
|
||||
if errors.Is(err, errTokenMissing) {
|
||||
fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo")
|
||||
fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ghcs.ErrTooManyArgs) {
|
||||
_ = cmd.Usage()
|
||||
return
|
||||
}
|
||||
}
|
||||
659
cmd/ghcs/mock_api.go
Normal file
659
cmd/ghcs/mock_api.go
Normal file
|
|
@ -0,0 +1,659 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
)
|
||||
|
||||
// apiClientMock is a mock implementation of apiClient.
|
||||
//
|
||||
// func TestSomethingThatUsesapiClient(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked apiClient
|
||||
// mockedapiClient := &apiClientMock{
|
||||
// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) {
|
||||
// panic("mock out the AuthorizedKeys method")
|
||||
// },
|
||||
// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
|
||||
// panic("mock out the CreateCodespace method")
|
||||
// },
|
||||
// DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error {
|
||||
// panic("mock out the DeleteCodespace method")
|
||||
// },
|
||||
// GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
|
||||
// panic("mock out the GetCodespace method")
|
||||
// },
|
||||
// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
|
||||
// panic("mock out the GetCodespaceRegionLocation method")
|
||||
// },
|
||||
// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
|
||||
// panic("mock out the GetCodespaceRepositoryContents method")
|
||||
// },
|
||||
// GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) {
|
||||
// panic("mock out the GetCodespaceToken method")
|
||||
// },
|
||||
// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
|
||||
// panic("mock out the GetCodespacesSKUs method")
|
||||
// },
|
||||
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
|
||||
// panic("mock out the GetRepository method")
|
||||
// },
|
||||
// GetUserFunc: func(ctx context.Context) (*api.User, error) {
|
||||
// panic("mock out the GetUser method")
|
||||
// },
|
||||
// ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) {
|
||||
// panic("mock out the ListCodespaces method")
|
||||
// },
|
||||
// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error {
|
||||
// panic("mock out the StartCodespace method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedapiClient in code that requires apiClient
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type apiClientMock struct {
|
||||
// AuthorizedKeysFunc mocks the AuthorizedKeys method.
|
||||
AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error)
|
||||
|
||||
// CreateCodespaceFunc mocks the CreateCodespace method.
|
||||
CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
|
||||
|
||||
// DeleteCodespaceFunc mocks the DeleteCodespace method.
|
||||
DeleteCodespaceFunc func(ctx context.Context, user string, name string) error
|
||||
|
||||
// GetCodespaceFunc mocks the GetCodespace method.
|
||||
GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error)
|
||||
|
||||
// GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method.
|
||||
GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error)
|
||||
|
||||
// GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method.
|
||||
GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
|
||||
|
||||
// GetCodespaceTokenFunc mocks the GetCodespaceToken method.
|
||||
GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error)
|
||||
|
||||
// GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method.
|
||||
GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error)
|
||||
|
||||
// GetRepositoryFunc mocks the GetRepository method.
|
||||
GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error)
|
||||
|
||||
// GetUserFunc mocks the GetUser method.
|
||||
GetUserFunc func(ctx context.Context) (*api.User, error)
|
||||
|
||||
// ListCodespacesFunc mocks the ListCodespaces method.
|
||||
ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error)
|
||||
|
||||
// StartCodespaceFunc mocks the StartCodespace method.
|
||||
StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// AuthorizedKeys holds details about calls to the AuthorizedKeys method.
|
||||
AuthorizedKeys []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// User is the user argument value.
|
||||
User string
|
||||
}
|
||||
// CreateCodespace holds details about calls to the CreateCodespace method.
|
||||
CreateCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// Params is the params argument value.
|
||||
Params *api.CreateCodespaceParams
|
||||
}
|
||||
// DeleteCodespace holds details about calls to the DeleteCodespace method.
|
||||
DeleteCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// User is the user argument value.
|
||||
User string
|
||||
// Name is the name argument value.
|
||||
Name string
|
||||
}
|
||||
// GetCodespace holds details about calls to the GetCodespace method.
|
||||
GetCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// Token is the token argument value.
|
||||
Token string
|
||||
// User is the user argument value.
|
||||
User string
|
||||
// Name is the name argument value.
|
||||
Name string
|
||||
}
|
||||
// GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method.
|
||||
GetCodespaceRegionLocation []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
}
|
||||
// GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method.
|
||||
GetCodespaceRepositoryContents []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// Codespace is the codespace argument value.
|
||||
Codespace *api.Codespace
|
||||
// Path is the path argument value.
|
||||
Path string
|
||||
}
|
||||
// GetCodespaceToken holds details about calls to the GetCodespaceToken method.
|
||||
GetCodespaceToken []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// User is the user argument value.
|
||||
User string
|
||||
// Name is the name argument value.
|
||||
Name string
|
||||
}
|
||||
// GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method.
|
||||
GetCodespacesSKUs []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// User is the user argument value.
|
||||
User *api.User
|
||||
// Repository is the repository argument value.
|
||||
Repository *api.Repository
|
||||
// Branch is the branch argument value.
|
||||
Branch string
|
||||
// Location is the location argument value.
|
||||
Location string
|
||||
}
|
||||
// GetRepository holds details about calls to the GetRepository method.
|
||||
GetRepository []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// Nwo is the nwo argument value.
|
||||
Nwo string
|
||||
}
|
||||
// GetUser holds details about calls to the GetUser method.
|
||||
GetUser []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
}
|
||||
// ListCodespaces holds details about calls to the ListCodespaces method.
|
||||
ListCodespaces []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// User is the user argument value.
|
||||
User string
|
||||
}
|
||||
// StartCodespace holds details about calls to the StartCodespace method.
|
||||
StartCodespace []struct {
|
||||
// Ctx is the ctx argument value.
|
||||
Ctx context.Context
|
||||
// Token is the token argument value.
|
||||
Token string
|
||||
// Codespace is the codespace argument value.
|
||||
Codespace *api.Codespace
|
||||
}
|
||||
}
|
||||
lockAuthorizedKeys sync.RWMutex
|
||||
lockCreateCodespace sync.RWMutex
|
||||
lockDeleteCodespace sync.RWMutex
|
||||
lockGetCodespace sync.RWMutex
|
||||
lockGetCodespaceRegionLocation sync.RWMutex
|
||||
lockGetCodespaceRepositoryContents sync.RWMutex
|
||||
lockGetCodespaceToken sync.RWMutex
|
||||
lockGetCodespacesSKUs sync.RWMutex
|
||||
lockGetRepository sync.RWMutex
|
||||
lockGetUser sync.RWMutex
|
||||
lockListCodespaces sync.RWMutex
|
||||
lockStartCodespace sync.RWMutex
|
||||
}
|
||||
|
||||
// AuthorizedKeys calls AuthorizedKeysFunc.
|
||||
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
|
||||
if mock.AuthorizedKeysFunc == nil {
|
||||
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
User: user,
|
||||
}
|
||||
mock.lockAuthorizedKeys.Lock()
|
||||
mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo)
|
||||
mock.lockAuthorizedKeys.Unlock()
|
||||
return mock.AuthorizedKeysFunc(ctx, user)
|
||||
}
|
||||
|
||||
// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.AuthorizedKeysCalls())
|
||||
func (mock *apiClientMock) AuthorizedKeysCalls() []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
}
|
||||
mock.lockAuthorizedKeys.RLock()
|
||||
calls = mock.calls.AuthorizedKeys
|
||||
mock.lockAuthorizedKeys.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// CreateCodespace calls CreateCodespaceFunc.
|
||||
func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
|
||||
if mock.CreateCodespaceFunc == nil {
|
||||
panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
Params *api.CreateCodespaceParams
|
||||
}{
|
||||
Ctx: ctx,
|
||||
Params: params,
|
||||
}
|
||||
mock.lockCreateCodespace.Lock()
|
||||
mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo)
|
||||
mock.lockCreateCodespace.Unlock()
|
||||
return mock.CreateCodespaceFunc(ctx, params)
|
||||
}
|
||||
|
||||
// CreateCodespaceCalls gets all the calls that were made to CreateCodespace.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.CreateCodespaceCalls())
|
||||
func (mock *apiClientMock) CreateCodespaceCalls() []struct {
|
||||
Ctx context.Context
|
||||
Params *api.CreateCodespaceParams
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
Params *api.CreateCodespaceParams
|
||||
}
|
||||
mock.lockCreateCodespace.RLock()
|
||||
calls = mock.calls.CreateCodespace
|
||||
mock.lockCreateCodespace.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// DeleteCodespace calls DeleteCodespaceFunc.
|
||||
func (mock *apiClientMock) DeleteCodespace(ctx context.Context, user string, name string) error {
|
||||
if mock.DeleteCodespaceFunc == nil {
|
||||
panic("apiClientMock.DeleteCodespaceFunc: method is nil but apiClient.DeleteCodespace was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
Name string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
User: user,
|
||||
Name: name,
|
||||
}
|
||||
mock.lockDeleteCodespace.Lock()
|
||||
mock.calls.DeleteCodespace = append(mock.calls.DeleteCodespace, callInfo)
|
||||
mock.lockDeleteCodespace.Unlock()
|
||||
return mock.DeleteCodespaceFunc(ctx, user, name)
|
||||
}
|
||||
|
||||
// DeleteCodespaceCalls gets all the calls that were made to DeleteCodespace.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.DeleteCodespaceCalls())
|
||||
func (mock *apiClientMock) DeleteCodespaceCalls() []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
Name string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
Name string
|
||||
}
|
||||
mock.lockDeleteCodespace.RLock()
|
||||
calls = mock.calls.DeleteCodespace
|
||||
mock.lockDeleteCodespace.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetCodespace calls GetCodespaceFunc.
|
||||
func (mock *apiClientMock) GetCodespace(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
|
||||
if mock.GetCodespaceFunc == nil {
|
||||
panic("apiClientMock.GetCodespaceFunc: method is nil but apiClient.GetCodespace was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
Token string
|
||||
User string
|
||||
Name string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
Token: token,
|
||||
User: user,
|
||||
Name: name,
|
||||
}
|
||||
mock.lockGetCodespace.Lock()
|
||||
mock.calls.GetCodespace = append(mock.calls.GetCodespace, callInfo)
|
||||
mock.lockGetCodespace.Unlock()
|
||||
return mock.GetCodespaceFunc(ctx, token, user, name)
|
||||
}
|
||||
|
||||
// GetCodespaceCalls gets all the calls that were made to GetCodespace.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetCodespaceCalls())
|
||||
func (mock *apiClientMock) GetCodespaceCalls() []struct {
|
||||
Ctx context.Context
|
||||
Token string
|
||||
User string
|
||||
Name string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
Token string
|
||||
User string
|
||||
Name string
|
||||
}
|
||||
mock.lockGetCodespace.RLock()
|
||||
calls = mock.calls.GetCodespace
|
||||
mock.lockGetCodespace.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc.
|
||||
func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
|
||||
if mock.GetCodespaceRegionLocationFunc == nil {
|
||||
panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
}{
|
||||
Ctx: ctx,
|
||||
}
|
||||
mock.lockGetCodespaceRegionLocation.Lock()
|
||||
mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo)
|
||||
mock.lockGetCodespaceRegionLocation.Unlock()
|
||||
return mock.GetCodespaceRegionLocationFunc(ctx)
|
||||
}
|
||||
|
||||
// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetCodespaceRegionLocationCalls())
|
||||
func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct {
|
||||
Ctx context.Context
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
}
|
||||
mock.lockGetCodespaceRegionLocation.RLock()
|
||||
calls = mock.calls.GetCodespaceRegionLocation
|
||||
mock.lockGetCodespaceRegionLocation.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc.
|
||||
func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
|
||||
if mock.GetCodespaceRepositoryContentsFunc == nil {
|
||||
panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
Codespace *api.Codespace
|
||||
Path string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
Codespace: codespace,
|
||||
Path: path,
|
||||
}
|
||||
mock.lockGetCodespaceRepositoryContents.Lock()
|
||||
mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo)
|
||||
mock.lockGetCodespaceRepositoryContents.Unlock()
|
||||
return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path)
|
||||
}
|
||||
|
||||
// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetCodespaceRepositoryContentsCalls())
|
||||
func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct {
|
||||
Ctx context.Context
|
||||
Codespace *api.Codespace
|
||||
Path string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
Codespace *api.Codespace
|
||||
Path string
|
||||
}
|
||||
mock.lockGetCodespaceRepositoryContents.RLock()
|
||||
calls = mock.calls.GetCodespaceRepositoryContents
|
||||
mock.lockGetCodespaceRepositoryContents.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetCodespaceToken calls GetCodespaceTokenFunc.
|
||||
func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) {
|
||||
if mock.GetCodespaceTokenFunc == nil {
|
||||
panic("apiClientMock.GetCodespaceTokenFunc: method is nil but apiClient.GetCodespaceToken was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
Name string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
User: user,
|
||||
Name: name,
|
||||
}
|
||||
mock.lockGetCodespaceToken.Lock()
|
||||
mock.calls.GetCodespaceToken = append(mock.calls.GetCodespaceToken, callInfo)
|
||||
mock.lockGetCodespaceToken.Unlock()
|
||||
return mock.GetCodespaceTokenFunc(ctx, user, name)
|
||||
}
|
||||
|
||||
// GetCodespaceTokenCalls gets all the calls that were made to GetCodespaceToken.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetCodespaceTokenCalls())
|
||||
func (mock *apiClientMock) GetCodespaceTokenCalls() []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
Name string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
Name string
|
||||
}
|
||||
mock.lockGetCodespaceToken.RLock()
|
||||
calls = mock.calls.GetCodespaceToken
|
||||
mock.lockGetCodespaceToken.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetCodespacesSKUs calls GetCodespacesSKUsFunc.
|
||||
func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
|
||||
if mock.GetCodespacesSKUsFunc == nil {
|
||||
panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
User *api.User
|
||||
Repository *api.Repository
|
||||
Branch string
|
||||
Location string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
User: user,
|
||||
Repository: repository,
|
||||
Branch: branch,
|
||||
Location: location,
|
||||
}
|
||||
mock.lockGetCodespacesSKUs.Lock()
|
||||
mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo)
|
||||
mock.lockGetCodespacesSKUs.Unlock()
|
||||
return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location)
|
||||
}
|
||||
|
||||
// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetCodespacesSKUsCalls())
|
||||
func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct {
|
||||
Ctx context.Context
|
||||
User *api.User
|
||||
Repository *api.Repository
|
||||
Branch string
|
||||
Location string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
User *api.User
|
||||
Repository *api.Repository
|
||||
Branch string
|
||||
Location string
|
||||
}
|
||||
mock.lockGetCodespacesSKUs.RLock()
|
||||
calls = mock.calls.GetCodespacesSKUs
|
||||
mock.lockGetCodespacesSKUs.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetRepository calls GetRepositoryFunc.
|
||||
func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) {
|
||||
if mock.GetRepositoryFunc == nil {
|
||||
panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
Nwo string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
Nwo: nwo,
|
||||
}
|
||||
mock.lockGetRepository.Lock()
|
||||
mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo)
|
||||
mock.lockGetRepository.Unlock()
|
||||
return mock.GetRepositoryFunc(ctx, nwo)
|
||||
}
|
||||
|
||||
// GetRepositoryCalls gets all the calls that were made to GetRepository.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetRepositoryCalls())
|
||||
func (mock *apiClientMock) GetRepositoryCalls() []struct {
|
||||
Ctx context.Context
|
||||
Nwo string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
Nwo string
|
||||
}
|
||||
mock.lockGetRepository.RLock()
|
||||
calls = mock.calls.GetRepository
|
||||
mock.lockGetRepository.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// GetUser calls GetUserFunc.
|
||||
func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) {
|
||||
if mock.GetUserFunc == nil {
|
||||
panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
}{
|
||||
Ctx: ctx,
|
||||
}
|
||||
mock.lockGetUser.Lock()
|
||||
mock.calls.GetUser = append(mock.calls.GetUser, callInfo)
|
||||
mock.lockGetUser.Unlock()
|
||||
return mock.GetUserFunc(ctx)
|
||||
}
|
||||
|
||||
// GetUserCalls gets all the calls that were made to GetUser.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.GetUserCalls())
|
||||
func (mock *apiClientMock) GetUserCalls() []struct {
|
||||
Ctx context.Context
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
}
|
||||
mock.lockGetUser.RLock()
|
||||
calls = mock.calls.GetUser
|
||||
mock.lockGetUser.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// ListCodespaces calls ListCodespacesFunc.
|
||||
func (mock *apiClientMock) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) {
|
||||
if mock.ListCodespacesFunc == nil {
|
||||
panic("apiClientMock.ListCodespacesFunc: method is nil but apiClient.ListCodespaces was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
}{
|
||||
Ctx: ctx,
|
||||
User: user,
|
||||
}
|
||||
mock.lockListCodespaces.Lock()
|
||||
mock.calls.ListCodespaces = append(mock.calls.ListCodespaces, callInfo)
|
||||
mock.lockListCodespaces.Unlock()
|
||||
return mock.ListCodespacesFunc(ctx, user)
|
||||
}
|
||||
|
||||
// ListCodespacesCalls gets all the calls that were made to ListCodespaces.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.ListCodespacesCalls())
|
||||
func (mock *apiClientMock) ListCodespacesCalls() []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
User string
|
||||
}
|
||||
mock.lockListCodespaces.RLock()
|
||||
calls = mock.calls.ListCodespaces
|
||||
mock.lockListCodespaces.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// StartCodespace calls StartCodespaceFunc.
|
||||
func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error {
|
||||
if mock.StartCodespaceFunc == nil {
|
||||
panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Ctx context.Context
|
||||
Token string
|
||||
Codespace *api.Codespace
|
||||
}{
|
||||
Ctx: ctx,
|
||||
Token: token,
|
||||
Codespace: codespace,
|
||||
}
|
||||
mock.lockStartCodespace.Lock()
|
||||
mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo)
|
||||
mock.lockStartCodespace.Unlock()
|
||||
return mock.StartCodespaceFunc(ctx, token, codespace)
|
||||
}
|
||||
|
||||
// StartCodespaceCalls gets all the calls that were made to StartCodespace.
|
||||
// Check the length with:
|
||||
// len(mockedapiClient.StartCodespaceCalls())
|
||||
func (mock *apiClientMock) StartCodespaceCalls() []struct {
|
||||
Ctx context.Context
|
||||
Token string
|
||||
Codespace *api.Codespace
|
||||
} {
|
||||
var calls []struct {
|
||||
Ctx context.Context
|
||||
Token string
|
||||
Codespace *api.Codespace
|
||||
}
|
||||
mock.lockStartCodespace.RLock()
|
||||
calls = mock.calls.StartCodespace
|
||||
mock.lockStartCodespace.RUnlock()
|
||||
return calls
|
||||
}
|
||||
69
cmd/ghcs/mock_prompter.go
Normal file
69
cmd/ghcs/mock_prompter.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package ghcs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// prompterMock is a mock implementation of prompter.
|
||||
//
|
||||
// func TestSomethingThatUsesprompter(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked prompter
|
||||
// mockedprompter := &prompterMock{
|
||||
// ConfirmFunc: func(message string) (bool, error) {
|
||||
// panic("mock out the Confirm method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedprompter in code that requires prompter
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type prompterMock struct {
|
||||
// ConfirmFunc mocks the Confirm method.
|
||||
ConfirmFunc func(message string) (bool, error)
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// Confirm holds details about calls to the Confirm method.
|
||||
Confirm []struct {
|
||||
// Message is the message argument value.
|
||||
Message string
|
||||
}
|
||||
}
|
||||
lockConfirm sync.RWMutex
|
||||
}
|
||||
|
||||
// Confirm calls ConfirmFunc.
|
||||
func (mock *prompterMock) Confirm(message string) (bool, error) {
|
||||
if mock.ConfirmFunc == nil {
|
||||
panic("prompterMock.ConfirmFunc: method is nil but prompter.Confirm was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
Message string
|
||||
}{
|
||||
Message: message,
|
||||
}
|
||||
mock.lockConfirm.Lock()
|
||||
mock.calls.Confirm = append(mock.calls.Confirm, callInfo)
|
||||
mock.lockConfirm.Unlock()
|
||||
return mock.ConfirmFunc(message)
|
||||
}
|
||||
|
||||
// ConfirmCalls gets all the calls that were made to Confirm.
|
||||
// Check the length with:
|
||||
// len(mockedprompter.ConfirmCalls())
|
||||
func (mock *prompterMock) ConfirmCalls() []struct {
|
||||
Message string
|
||||
} {
|
||||
var calls []struct {
|
||||
Message string
|
||||
}
|
||||
mock.lockConfirm.RLock()
|
||||
calls = mock.calls.Confirm
|
||||
mock.lockConfirm.RUnlock()
|
||||
return calls
|
||||
}
|
||||
55
cmd/ghcs/output/format_json.go
Normal file
55
cmd/ghcs/output/format_json.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package output
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type jsonwriter struct {
|
||||
w io.Writer
|
||||
pretty bool
|
||||
cols []string
|
||||
data []interface{}
|
||||
}
|
||||
|
||||
func (j *jsonwriter) SetHeader(cols []string) {
|
||||
j.cols = cols
|
||||
}
|
||||
|
||||
func (j *jsonwriter) Append(values []string) {
|
||||
row := make(map[string]string)
|
||||
for i, v := range values {
|
||||
row[camelize(j.cols[i])] = v
|
||||
}
|
||||
j.data = append(j.data, row)
|
||||
}
|
||||
|
||||
func (j *jsonwriter) Render() {
|
||||
enc := json.NewEncoder(j.w)
|
||||
if j.pretty {
|
||||
enc.SetIndent("", " ")
|
||||
}
|
||||
_ = enc.Encode(j.data)
|
||||
}
|
||||
|
||||
func camelize(s string) string {
|
||||
var b strings.Builder
|
||||
capitalizeNext := false
|
||||
for i, r := range s {
|
||||
if r == ' ' {
|
||||
capitalizeNext = true
|
||||
continue
|
||||
}
|
||||
if capitalizeNext {
|
||||
b.WriteRune(unicode.ToUpper(r))
|
||||
capitalizeNext = false
|
||||
} else if i == 0 {
|
||||
b.WriteRune(unicode.ToLower(r))
|
||||
} else {
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
31
cmd/ghcs/output/format_table.go
Normal file
31
cmd/ghcs/output/format_table.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package output
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type Table interface {
|
||||
SetHeader([]string)
|
||||
Append([]string)
|
||||
Render()
|
||||
}
|
||||
|
||||
func NewTable(w io.Writer, asJSON bool) Table {
|
||||
isTTY := isTTY(w)
|
||||
if asJSON {
|
||||
return &jsonwriter{w: w, pretty: isTTY}
|
||||
}
|
||||
if isTTY {
|
||||
return tablewriter.NewWriter(w)
|
||||
}
|
||||
return &tabwriter{w: w}
|
||||
}
|
||||
|
||||
func isTTY(w io.Writer) bool {
|
||||
f, ok := w.(*os.File)
|
||||
return ok && term.IsTerminal(int(f.Fd()))
|
||||
}
|
||||
25
cmd/ghcs/output/format_tsv.go
Normal file
25
cmd/ghcs/output/format_tsv.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
package output
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type tabwriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (j *tabwriter) SetHeader([]string) {}
|
||||
|
||||
func (j *tabwriter) Append(values []string) {
|
||||
var sep string
|
||||
for i, v := range values {
|
||||
if i == 1 {
|
||||
sep = "\t"
|
||||
}
|
||||
fmt.Fprintf(j.w, "%s%s", sep, v)
|
||||
}
|
||||
fmt.Fprint(j.w, "\n")
|
||||
}
|
||||
|
||||
func (j *tabwriter) Render() {}
|
||||
74
cmd/ghcs/output/logger.go
Normal file
74
cmd/ghcs/output/logger.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package output
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
|
||||
// Disable the Logger to prevent it from writing to stdout in a TTY environment.
|
||||
func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
|
||||
return &Logger{
|
||||
out: stdout,
|
||||
errout: stderr,
|
||||
enabled: !disabled && isTTY(stdout),
|
||||
}
|
||||
}
|
||||
|
||||
// Logger writes to the given stdout/stderr writers.
|
||||
// If not enabled, Print functions will noop but Error functions will continue
|
||||
// to write to the stderr writer.
|
||||
type Logger struct {
|
||||
mu sync.Mutex // guards the writers
|
||||
out io.Writer
|
||||
errout io.Writer
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// Print writes the arguments to the stdout writer.
|
||||
func (l *Logger) Print(v ...interface{}) (int, error) {
|
||||
if !l.enabled {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return fmt.Fprint(l.out, v...)
|
||||
}
|
||||
|
||||
// Println writes the arguments to the stdout writer with a newline at the end.
|
||||
func (l *Logger) Println(v ...interface{}) (int, error) {
|
||||
if !l.enabled {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return fmt.Fprintln(l.out, v...)
|
||||
}
|
||||
|
||||
// Printf writes the formatted arguments to the stdout writer.
|
||||
func (l *Logger) Printf(f string, v ...interface{}) (int, error) {
|
||||
if !l.enabled {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return fmt.Fprintf(l.out, f, v...)
|
||||
}
|
||||
|
||||
// Errorf writes the formatted arguments to the stderr writer.
|
||||
func (l *Logger) Errorf(f string, v ...interface{}) (int, error) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return fmt.Fprintf(l.errout, f, v...)
|
||||
}
|
||||
|
||||
// Errorln writes the arguments to the stderr writer with a newline at the end.
|
||||
func (l *Logger) Errorln(v ...interface{}) (int, error) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return fmt.Fprintln(l.errout, v...)
|
||||
}
|
||||
330
cmd/ghcs/ports.go
Normal file
330
cmd/ghcs/ports.go
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/liveshare"
|
||||
"github.com/muhammadmuzzammil1998/jsonc"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
|
||||
// according to the specified flags.
|
||||
func newPortsCmd(app *App) *cobra.Command {
|
||||
var (
|
||||
codespace string
|
||||
asJSON bool
|
||||
)
|
||||
|
||||
portsCmd := &cobra.Command{
|
||||
Use: "ports",
|
||||
Short: "List ports in a codespace",
|
||||
Args: noArgsConstraint,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.ListPorts(cmd.Context(), codespace, asJSON)
|
||||
},
|
||||
}
|
||||
|
||||
portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
|
||||
portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
|
||||
|
||||
portsCmd.AddCommand(newPortsPublicCmd(app))
|
||||
portsCmd.AddCommand(newPortsPrivateCmd(app))
|
||||
portsCmd.AddCommand(newPortsForwardCmd(app))
|
||||
|
||||
return portsCmd
|
||||
}
|
||||
|
||||
// ListPorts lists known ports in a codespace.
|
||||
func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) {
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
|
||||
if err != nil {
|
||||
// TODO(josebalius): remove special handling of this error here and it other places
|
||||
if err == errNoCodespaces {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("error choosing codespace: %w", err)
|
||||
}
|
||||
|
||||
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
a.logger.Println("Loading ports...")
|
||||
ports, err := session.GetSharedServers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ports of shared servers: %w", err)
|
||||
}
|
||||
|
||||
devContainerResult := <-devContainerCh
|
||||
if devContainerResult.err != nil {
|
||||
// Warn about failure to read the devcontainer file. Not a ghcs command error.
|
||||
_, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error())
|
||||
}
|
||||
|
||||
table := output.NewTable(os.Stdout, asJSON)
|
||||
table.SetHeader([]string{"Label", "Port", "Public", "Browse URL"})
|
||||
for _, port := range ports {
|
||||
sourcePort := strconv.Itoa(port.SourcePort)
|
||||
var portName string
|
||||
if devContainerResult.devContainer != nil {
|
||||
if attributes, ok := devContainerResult.devContainer.PortAttributes[sourcePort]; ok {
|
||||
portName = attributes.Label
|
||||
}
|
||||
}
|
||||
|
||||
table.Append([]string{
|
||||
portName,
|
||||
sourcePort,
|
||||
strings.ToUpper(strconv.FormatBool(port.IsPublic)),
|
||||
fmt.Sprintf("https://%s-%s.githubpreview.dev/", codespace.Name, sourcePort),
|
||||
})
|
||||
}
|
||||
table.Render()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type devContainerResult struct {
|
||||
devContainer *devContainer
|
||||
err error
|
||||
}
|
||||
|
||||
type devContainer struct {
|
||||
PortAttributes map[string]portAttribute `json:"portsAttributes"`
|
||||
}
|
||||
|
||||
type portAttribute struct {
|
||||
Label string `json:"label"`
|
||||
}
|
||||
|
||||
func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult {
|
||||
ch := make(chan devContainerResult, 1)
|
||||
go func() {
|
||||
contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json")
|
||||
if err != nil {
|
||||
ch <- devContainerResult{nil, fmt.Errorf("error getting content: %w", err)}
|
||||
return
|
||||
}
|
||||
|
||||
if contents == nil {
|
||||
ch <- devContainerResult{nil, nil}
|
||||
return
|
||||
}
|
||||
|
||||
convertedJSON := normalizeJSON(jsonc.ToJSON(contents))
|
||||
if !jsonc.Valid(convertedJSON) {
|
||||
ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")}
|
||||
return
|
||||
}
|
||||
|
||||
var container devContainer
|
||||
if err := json.Unmarshal(convertedJSON, &container); err != nil {
|
||||
ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %w", err)}
|
||||
return
|
||||
}
|
||||
|
||||
ch <- devContainerResult{&container, nil}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public.
|
||||
func newPortsPublicCmd(app *App) *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "public <port>",
|
||||
Short: "Mark port as public",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
codespace, err := cmd.Flags().GetString("codespace")
|
||||
if err != nil {
|
||||
// should only happen if flag is not defined
|
||||
// or if the flag is not of string type
|
||||
// since it's a persistent flag that we control it should never happen
|
||||
return fmt.Errorf("get codespace flag: %w", err)
|
||||
}
|
||||
|
||||
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private.
|
||||
func newPortsPrivateCmd(app *App) *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "private <port>",
|
||||
Short: "Mark port as private",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
codespace, err := cmd.Flags().GetString("codespace")
|
||||
if err != nil {
|
||||
// should only happen if flag is not defined
|
||||
// or if the flag is not of string type
|
||||
// since it's a persistent flag that we control it should never happen
|
||||
return fmt.Errorf("get codespace flag: %w", err)
|
||||
}
|
||||
|
||||
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) {
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
|
||||
if err != nil {
|
||||
if err == errNoCodespaces {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("error getting codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
port, err := strconv.Atoi(sourcePort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading port number: %w", err)
|
||||
}
|
||||
|
||||
if err := session.UpdateSharedVisibility(ctx, port, public); err != nil {
|
||||
return fmt.Errorf("error update port to public: %w", err)
|
||||
}
|
||||
|
||||
state := "PUBLIC"
|
||||
if !public {
|
||||
state = "PRIVATE"
|
||||
}
|
||||
a.logger.Printf("Port %s is now %s.\n", sourcePort, state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of
|
||||
// port pairs from the codespace to localhost.
|
||||
func newPortsForwardCmd(app *App) *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "forward <remote-port>:<local-port>...",
|
||||
Short: "Forward ports",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
codespace, err := cmd.Flags().GetString("codespace")
|
||||
if err != nil {
|
||||
// should only happen if flag is not defined
|
||||
// or if the flag is not of string type
|
||||
// since it's a persistent flag that we control it should never happen
|
||||
return fmt.Errorf("get codespace flag: %w", err)
|
||||
}
|
||||
|
||||
return app.ForwardPorts(cmd.Context(), codespace, args)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) {
|
||||
portPairs, err := getPortPairs(ports)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get port pairs: %w", err)
|
||||
}
|
||||
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
|
||||
if err != nil {
|
||||
if err == errNoCodespaces {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("error getting codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
// Run forwarding of all ports concurrently, aborting all of
|
||||
// them at the first failure, including cancellation of the context.
|
||||
group, ctx := errgroup.WithContext(ctx)
|
||||
for _, pair := range portPairs {
|
||||
pair := pair
|
||||
group.Go(func() error {
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
|
||||
name := fmt.Sprintf("share-%d", pair.remote)
|
||||
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
|
||||
return fwd.ForwardToListener(ctx, listen) // error always non-nil
|
||||
})
|
||||
}
|
||||
return group.Wait() // first error
|
||||
}
|
||||
|
||||
type portPair struct {
|
||||
remote, local int
|
||||
}
|
||||
|
||||
// getPortPairs parses a list of strings of form "%d:%d" into pairs of (remote, local) numbers.
|
||||
func getPortPairs(ports []string) ([]portPair, error) {
|
||||
pp := make([]portPair, 0, len(ports))
|
||||
|
||||
for _, portString := range ports {
|
||||
parts := strings.Split(portString, ":")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("port pair: %q is not valid", portString)
|
||||
}
|
||||
|
||||
remote, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return pp, fmt.Errorf("convert remote port to int: %w", err)
|
||||
}
|
||||
|
||||
local, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return pp, fmt.Errorf("convert local port to int: %w", err)
|
||||
}
|
||||
|
||||
pp = append(pp, portPair{remote, local})
|
||||
}
|
||||
|
||||
return pp, nil
|
||||
}
|
||||
|
||||
func normalizeJSON(j []byte) []byte {
|
||||
// remove trailing commas
|
||||
return bytes.ReplaceAll(j, []byte("},}"), []byte("}}"))
|
||||
}
|
||||
30
cmd/ghcs/root.go
Normal file
30
cmd/ghcs/root.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number.
|
||||
|
||||
func NewRootCmd(app *App) *cobra.Command {
|
||||
root := &cobra.Command{
|
||||
Use: "ghcs",
|
||||
SilenceUsage: true, // don't print usage message after each error (see #80)
|
||||
SilenceErrors: false, // print errors automatically so that main need not
|
||||
Long: `Unofficial CLI tool to manage GitHub Codespaces.
|
||||
|
||||
Running commands requires the GITHUB_TOKEN environment variable to be set to a
|
||||
token to access the GitHub API with.`,
|
||||
Version: version,
|
||||
}
|
||||
|
||||
root.AddCommand(newCodeCmd(app))
|
||||
root.AddCommand(newCreateCmd(app))
|
||||
root.AddCommand(newDeleteCmd(app))
|
||||
root.AddCommand(newListCmd(app))
|
||||
root.AddCommand(newLogsCmd(app))
|
||||
root.AddCommand(newPortsCmd(app))
|
||||
root.AddCommand(newSSHCmd(app))
|
||||
|
||||
return root
|
||||
}
|
||||
105
cmd/ghcs/ssh.go
Normal file
105
cmd/ghcs/ssh.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package ghcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newSSHCmd(app *App) *cobra.Command {
|
||||
var sshProfile, codespaceName string
|
||||
var sshServerPort int
|
||||
|
||||
sshCmd := &cobra.Command{
|
||||
Use: "ssh [flags] [--] [ssh-flags] [command]",
|
||||
Short: "SSH into a codespace",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort)
|
||||
},
|
||||
}
|
||||
|
||||
sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use")
|
||||
sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
|
||||
sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace")
|
||||
|
||||
return sshCmd
|
||||
}
|
||||
|
||||
// SSH opens an ssh session or runs an ssh command in a codespace.
|
||||
func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
|
||||
// Ensure all child tasks (e.g. port forwarding) terminate before return.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
user, err := a.apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
authkeys := make(chan error, 1)
|
||||
go func() {
|
||||
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
|
||||
}()
|
||||
|
||||
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get or choose codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
if err := <-authkeys; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.logger.Println("Fetching SSH Details...")
|
||||
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ssh server details: %w", err)
|
||||
}
|
||||
|
||||
usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
|
||||
|
||||
// Ensure local port is listening before client (Shell) connects.
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
connectDestination := sshProfile
|
||||
if connectDestination == "" {
|
||||
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
|
||||
}
|
||||
|
||||
a.logger.Println("Ready...")
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
|
||||
}()
|
||||
|
||||
shellClosed := make(chan error, 1)
|
||||
go func() {
|
||||
shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-tunnelClosed:
|
||||
return fmt.Errorf("tunnel closed: %w", err)
|
||||
case err := <-shellClosed:
|
||||
if err != nil {
|
||||
return fmt.Errorf("shell closed: %w", err)
|
||||
}
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
13
go.mod
13
go.mod
|
|
@ -12,9 +12,11 @@ require (
|
|||
github.com/cli/safeexec v1.0.0
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0
|
||||
github.com/creack/pty v1.1.13
|
||||
github.com/fatih/camelcase v1.0.0
|
||||
github.com/gabriel-vasile/mimetype v1.1.2
|
||||
github.com/google/go-cmp v0.5.5
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/hashicorp/go-version v1.2.1
|
||||
github.com/henvic/httpretty v0.0.6
|
||||
github.com/itchyny/gojq v0.12.4
|
||||
|
|
@ -24,17 +26,24 @@ require (
|
|||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
|
||||
github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5
|
||||
github.com/muesli/termenv v0.8.1
|
||||
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/opentracing/opentracing-go v1.1.0
|
||||
github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b
|
||||
github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/sourcegraph/jsonrpc2 v0.1.0
|
||||
github.com/spf13/cobra v1.2.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/stretchr/objx v0.1.1 // indirect
|
||||
github.com/stretchr/testify v1.7.0
|
||||
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
|
||||
golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1
|
||||
golang.org/x/term v0.0.0-20210503060354-a79de5458b56
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
|
||||
)
|
||||
|
||||
replace github.com/shurcooL/graphql => github.com/cli/shurcooL-graphql v0.0.0-20200707151639-0f7232a2bf7e
|
||||
|
||||
replace golang.org/x/crypto => github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03
|
||||
|
|
|
|||
28
go.sum
28
go.sum
|
|
@ -73,6 +73,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
|
|||
github.com/cli/browser v1.0.0/go.mod h1:IEWkHYbLjkhtjwwWlwTHW2lGxeS5gezEQBMLTwDHf5Q=
|
||||
github.com/cli/browser v1.1.0 h1:xOZBfkfY9L9vMBgqb1YwRirGu6QFaQ5dP/vXt5ENSOY=
|
||||
github.com/cli/browser v1.1.0/go.mod h1:HKMQAt9t12kov91Mn7RfZxyJQQgWgyS/3SZswlZ5iTI=
|
||||
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai03eoWAI+oGHJwr+5OXfxCr8=
|
||||
github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
github.com/cli/oauth v0.8.0 h1:YTFgPXSTvvDUFti3tR4o6q7Oll2SnQ9ztLwCAn4/IOA=
|
||||
github.com/cli/oauth v0.8.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4=
|
||||
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
|
||||
|
|
@ -103,6 +105,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y
|
|||
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8=
|
||||
github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc=
|
||||
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
|
||||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
|
|
@ -182,6 +186,9 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
|
|||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
|
||||
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/graph-gophers/graphql-go v0.0.0-20200622220639-c1d9693c95a6/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q=
|
||||
|
|
@ -273,8 +280,11 @@ github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5 h1:T+Fc6qGlSfM+z0J
|
|||
github.com/muesli/reflow v0.2.1-0.20210502190812-c80126ec2ad5/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ=
|
||||
github.com/muesli/termenv v0.8.1 h1:9q230czSP3DHVpkaPDXGp0TOfAwyjyYwXlUCQxQSaBk=
|
||||
github.com/muesli/termenv v0.8.1/go.mod h1:kzt/D/4a88RoheZmwfqorY3A+tnsSMA9HJC/fQSFKo0=
|
||||
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38 h1:0FrBxrkJ0hVembTb/e4EU5Ml6vLcOusAqymmYISg5Uo=
|
||||
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38/go.mod h1:saF2fIVw4banK0H4+/EuqfFLpRnoy5S+ECwTOCcRcSU=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
|
||||
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
|
||||
|
|
@ -300,8 +310,12 @@ github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b h1:0/ecDXh/HTHRt
|
|||
github.com/shurcooL/githubv4 v0.0.0-20200928013246-d292edc3691b/go.mod h1:hAF0iLZy4td2EX+/8Tw+4nodhlMrwN3HupfaXj3zkGo=
|
||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
|
||||
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
|
||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
||||
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||
github.com/sourcegraph/jsonrpc2 v0.1.0 h1:ohJHjZ+PcaLxDUjqk2NC3tIGsVa5bXThe1ZheSXOjuk=
|
||||
github.com/sourcegraph/jsonrpc2 v0.1.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo=
|
||||
github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
|
||||
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
github.com/spf13/cobra v1.2.1 h1:+KmjbUw1hriSNMF55oPrkZcb27aECyrj8V2ytv7kWDw=
|
||||
|
|
@ -344,16 +358,6 @@ go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
|
|||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
|
||||
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E=
|
||||
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
|
|
@ -396,7 +400,6 @@ golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73r
|
|||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
|
|
@ -497,8 +500,9 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b h1:qh4f65QIVFjq9eBURLEYWqaEXmOyqdUyiBSgaXWccWk=
|
||||
golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210503060354-a79de5458b56 h1:b8jxX3zqjpqb2LklXPzKSGJhzyxCOZSz8ncv8Nv+y7w=
|
||||
golang.org/x/term v0.0.0-20210503060354-a79de5458b56/go.mod h1:tfny5GFUkzUvx4ps4ajbZsCe5lw1metzhBm9T3x7oIY=
|
||||
|
|
|
|||
640
internal/api/api.go
Normal file
640
internal/api/api.go
Normal file
|
|
@ -0,0 +1,640 @@
|
|||
package api
|
||||
|
||||
// For descriptions of service interfaces, see:
|
||||
// - https://online.visualstudio.com/api/swagger (for visualstudio.com)
|
||||
// - https://docs.github.com/en/rest/reference/repos (for api.github.com)
|
||||
// - https://github.com/github/github/blob/master/app/api/codespaces.rb (for vscs_internal)
|
||||
// TODO(adonovan): replace the last link with a public doc URL when available.
|
||||
|
||||
// TODO(adonovan): a possible reorganization would be to split this
|
||||
// file into three internal packages, one per backend service, and to
|
||||
// rename api.API to github.Client:
|
||||
//
|
||||
// - github.GetUser(github.Client)
|
||||
// - github.GetRepository(Client)
|
||||
// - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents
|
||||
// - github.AuthorizedKeys(Client, user)
|
||||
// - codespaces.Create(Client, user, repo, sku, branch, location)
|
||||
// - codespaces.Delete(Client, user, token, name)
|
||||
// - codespaces.Get(Client, token, owner, name)
|
||||
// - codespaces.GetMachineTypes(Client, user, repo, branch, location)
|
||||
// - codespaces.GetToken(Client, login, name)
|
||||
// - codespaces.List(Client, user)
|
||||
// - codespaces.Start(Client, token, codespace)
|
||||
// - visualstudio.GetRegionLocation(http.Client) // no dependency on github
|
||||
//
|
||||
// This would make the meaning of each operation clearer.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
const githubAPI = "https://api.github.com"
|
||||
|
||||
type API struct {
|
||||
token string
|
||||
client httpClient
|
||||
githubAPI string
|
||||
}
|
||||
|
||||
type httpClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func New(token string, httpClient httpClient) *API {
|
||||
return &API{
|
||||
token: token,
|
||||
client: httpClient,
|
||||
githubAPI: githubAPI,
|
||||
}
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Login string `json:"login"`
|
||||
}
|
||||
|
||||
func (a *API) GetUser(ctx context.Context) (*User, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/user", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/user")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response User
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func jsonErrorResponse(b []byte) error {
|
||||
var response struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return fmt.Errorf("error unmarshaling error response: %w", err)
|
||||
}
|
||||
|
||||
return errors.New(response.Message)
|
||||
}
|
||||
|
||||
type Repository struct {
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+strings.ToLower(nwo), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/repos/*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response Repository
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
type Codespace struct {
|
||||
Name string `json:"name"`
|
||||
GUID string `json:"guid"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastUsedAt string `json:"last_used_at"`
|
||||
Branch string `json:"branch"`
|
||||
RepositoryName string `json:"repository_name"`
|
||||
RepositoryNWO string `json:"repository_nwo"`
|
||||
OwnerLogin string `json:"owner_login"`
|
||||
Environment CodespaceEnvironment `json:"environment"`
|
||||
}
|
||||
|
||||
type CodespaceEnvironment struct {
|
||||
State string `json:"state"`
|
||||
Connection CodespaceEnvironmentConnection `json:"connection"`
|
||||
GitStatus CodespaceEnvironmentGitStatus `json:"gitStatus"`
|
||||
}
|
||||
|
||||
type CodespaceEnvironmentGitStatus struct {
|
||||
Ahead int `json:"ahead"`
|
||||
Behind int `json:"behind"`
|
||||
Branch string `json:"branch"`
|
||||
Commit string `json:"commit"`
|
||||
HasUnpushedChanges bool `json:"hasUnpushedChanges"`
|
||||
HasUncommitedChanges bool `json:"hasUncommitedChanges"`
|
||||
}
|
||||
|
||||
const (
|
||||
CodespaceEnvironmentStateAvailable = "Available"
|
||||
)
|
||||
|
||||
type CodespaceEnvironmentConnection struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
SessionToken string `json:"sessionToken"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelaySAS string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
}
|
||||
|
||||
func (a *API) ListCodespaces(ctx context.Context, user string) ([]*Codespace, error) {
|
||||
req, err := http.NewRequest(
|
||||
http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Codespaces []*Codespace `json:"codespaces"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
return response.Codespaces, nil
|
||||
}
|
||||
|
||||
type getCodespaceTokenRequest struct {
|
||||
MintRepositoryToken bool `json:"mint_repository_token"`
|
||||
}
|
||||
|
||||
type getCodespaceTokenResponse struct {
|
||||
RepositoryToken string `json:"repository_token"`
|
||||
}
|
||||
|
||||
// ErrNotProvisioned is returned by GetCodespacesToken to indicate that the
|
||||
// creation of a codespace is not yet complete and that the caller should try again.
|
||||
var ErrNotProvisioned = errors.New("codespace not provisioned")
|
||||
|
||||
func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) {
|
||||
reqBody, err := json.Marshal(getCodespaceTokenRequest{true})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error preparing request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(
|
||||
http.MethodPost,
|
||||
a.githubAPI+"/vscs_internal/user/"+ownerLogin+"/codespaces/"+codespaceName+"/token",
|
||||
bytes.NewBuffer(reqBody),
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*/token")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusUnprocessableEntity {
|
||||
return "", ErrNotProvisioned
|
||||
}
|
||||
|
||||
return "", jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response getCodespaceTokenResponse
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return "", fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return response.RepositoryToken, nil
|
||||
}
|
||||
|
||||
func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) (*Codespace, error) {
|
||||
req, err := http.NewRequest(
|
||||
http.MethodGet,
|
||||
a.githubAPI+"/vscs_internal/user/"+owner+"/codespaces/"+codespace,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
// TODO: use a.setHeaders()
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response Codespace
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codespace) error {
|
||||
req, err := http.NewRequest(
|
||||
http.MethodPost,
|
||||
a.githubAPI+"/vscs_internal/proxy/environments/"+codespace.GUID+"/start",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
// TODO: use a.setHeaders()
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Error response may be a numeric code or a JSON {"message": "..."}.
|
||||
if bytes.HasPrefix(b, []byte("{")) {
|
||||
return jsonErrorResponse(b) // probably JSON
|
||||
}
|
||||
if len(b) > 100 {
|
||||
b = append(b[:97], "..."...)
|
||||
}
|
||||
if strings.TrimSpace(string(b)) == "7" {
|
||||
// Non-HTTP 200 with error code 7 (EnvironmentNotShutdown) is benign.
|
||||
// Ignore it.
|
||||
} else {
|
||||
return fmt.Errorf("failed to start codespace: %s", b)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type getCodespaceRegionLocationResponse struct {
|
||||
Current string `json:"current"`
|
||||
}
|
||||
|
||||
func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://online.visualstudio.com/api/v1/locations", nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := a.do(ctx, req, req.URL.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response getCodespaceRegionLocationResponse
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return "", fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return response.Current, nil
|
||||
}
|
||||
|
||||
type SKU struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Add("location", location)
|
||||
q.Add("ref", branch)
|
||||
q.Add("repository_id", strconv.Itoa(repository.ID))
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
SKUs []*SKU `json:"skus"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return response.SKUs, nil
|
||||
}
|
||||
|
||||
// CreateCodespaceParams are the required parameters for provisioning a Codespace.
|
||||
type CreateCodespaceParams struct {
|
||||
User string
|
||||
RepositoryID int
|
||||
Branch, Machine, Location string
|
||||
}
|
||||
|
||||
// CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it
|
||||
// fails to create.
|
||||
func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) {
|
||||
codespace, err := a.startCreate(
|
||||
ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location,
|
||||
)
|
||||
if err != errProvisioningInProgress {
|
||||
return codespace, err
|
||||
}
|
||||
|
||||
// errProvisioningInProgress indicates that codespace creation did not complete
|
||||
// within the GitHub API RPC time limit (10s), so it continues asynchronously.
|
||||
// We must poll the server to discover the outcome.
|
||||
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name)
|
||||
if err != nil {
|
||||
if err == ErrNotProvisioned {
|
||||
// Do nothing. We expect this to fail until the codespace is provisioned
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get codespace token: %w", err)
|
||||
}
|
||||
|
||||
codespace, err = a.GetCodespace(ctx, token, params.User, codespace.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get codespace: %w", err)
|
||||
}
|
||||
|
||||
return codespace, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type startCreateRequest struct {
|
||||
RepositoryID int `json:"repository_id"`
|
||||
Ref string `json:"ref"`
|
||||
Location string `json:"location"`
|
||||
SkuName string `json:"sku_name"`
|
||||
}
|
||||
|
||||
var errProvisioningInProgress = errors.New("provisioning in progress")
|
||||
|
||||
// startCreate starts the creation of a codespace.
|
||||
// It may return success or an error, or errProvisioningInProgress indicating that the operation
|
||||
// did not complete before the GitHub API's time limit for RPCs (10s), in which case the caller
|
||||
// must poll the server to learn the outcome.
|
||||
func (a *API) startCreate(ctx context.Context, user string, repository int, sku, branch, location string) (*Codespace, error) {
|
||||
requestBody, err := json.Marshal(startCreateRequest{repository, branch, location, sku})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case resp.StatusCode > http.StatusAccepted:
|
||||
return nil, jsonErrorResponse(b)
|
||||
case resp.StatusCode == http.StatusAccepted:
|
||||
return nil, errProvisioningInProgress // RPC finished before result of creation known
|
||||
}
|
||||
|
||||
var response Codespace
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName string) error {
|
||||
token, err := a.GetCodespaceToken(ctx, user, codespaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting codespace token: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodDelete, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces/"+codespaceName, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
// TODO: use a.setHeaders()
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode > http.StatusAccepted {
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
return jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type getCodespaceRepositoryContentsResponse struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Codespace, path string) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Add("ref", codespace.Branch)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
a.setHeaders(req)
|
||||
resp, err := a.do(ctx, req, "/repos/*/contents/*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, jsonErrorResponse(b)
|
||||
}
|
||||
|
||||
var response getCodespaceRepositoryContentsResponse
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(response.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding content: %w", err)
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys
|
||||
// format) registered by the specified GitHub user.
|
||||
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
|
||||
url := fmt.Sprintf("https://github.com/%s.keys", user)
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.do(ctx, req, "/user.keys")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("server returned %s", resp.Status)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) {
|
||||
// TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter.
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, spanName)
|
||||
defer span.Finish()
|
||||
req = req.WithContext(ctx)
|
||||
return a.client.Do(req)
|
||||
}
|
||||
|
||||
func (a *API) setHeaders(req *http.Request) {
|
||||
if a.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+a.token)
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
}
|
||||
50
internal/api/api_test.go
Normal file
50
internal/api/api_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListCodespaces(t *testing.T) {
|
||||
codespaces := []*Codespace{
|
||||
{
|
||||
Name: "testcodespace",
|
||||
CreatedAt: "2021-08-09T10:10:24+02:00",
|
||||
LastUsedAt: "2021-08-09T13:10:24+02:00",
|
||||
},
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := struct {
|
||||
Codespaces []*Codespace `json:"codespaces"`
|
||||
}{
|
||||
Codespaces: codespaces,
|
||||
}
|
||||
data, _ := json.Marshal(response)
|
||||
fmt.Fprint(w, string(data))
|
||||
}))
|
||||
defer svr.Close()
|
||||
|
||||
api := API{
|
||||
githubAPI: svr.URL,
|
||||
client: &http.Client{},
|
||||
token: "faketoken",
|
||||
}
|
||||
ctx := context.TODO()
|
||||
codespaces, err := api.ListCodespaces(ctx, "testuser")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(codespaces) != 1 {
|
||||
t.Fatalf("expected 1 codespace, got %d", len(codespaces))
|
||||
}
|
||||
|
||||
if codespaces[0].Name != "testcodespace" {
|
||||
t.Fatalf("expected testcodespace, got %s", codespaces[0].Name)
|
||||
}
|
||||
|
||||
}
|
||||
77
internal/codespaces/codespaces.go
Normal file
77
internal/codespaces/codespaces.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package codespaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/cli/cli/v2/internal/liveshare"
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
Print(v ...interface{}) (int, error)
|
||||
Println(v ...interface{}) (int, error)
|
||||
}
|
||||
|
||||
func connectionReady(codespace *api.Codespace) bool {
|
||||
return codespace.Environment.Connection.SessionID != "" &&
|
||||
codespace.Environment.Connection.SessionToken != "" &&
|
||||
codespace.Environment.Connection.RelayEndpoint != "" &&
|
||||
codespace.Environment.Connection.RelaySAS != "" &&
|
||||
codespace.Environment.State == api.CodespaceEnvironmentStateAvailable
|
||||
}
|
||||
|
||||
type apiClient interface {
|
||||
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
|
||||
GetCodespaceToken(ctx context.Context, user, codespace string) (string, error)
|
||||
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
|
||||
}
|
||||
|
||||
// ConnectToLiveshare waits for a Codespace to become running,
|
||||
// and connects to it using a Live Share session.
|
||||
func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) {
|
||||
var startedCodespace bool
|
||||
if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable {
|
||||
startedCodespace = true
|
||||
log.Print("Starting your codespace...")
|
||||
if err := apiClient.StartCodespace(ctx, token, codespace); err != nil {
|
||||
return nil, fmt.Errorf("error starting codespace: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for retries := 0; !connectionReady(codespace); retries++ {
|
||||
if retries > 1 {
|
||||
if retries%2 == 0 {
|
||||
log.Print(".")
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
if retries == 30 {
|
||||
return nil, errors.New("timed out while waiting for the codespace to start")
|
||||
}
|
||||
|
||||
var err error
|
||||
codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting codespace: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if startedCodespace {
|
||||
fmt.Print("\n")
|
||||
}
|
||||
|
||||
log.Println("Connecting to your codespace...")
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
SessionID: codespace.Environment.Connection.SessionID,
|
||||
SessionToken: codespace.Environment.Connection.SessionToken,
|
||||
RelaySAS: codespace.Environment.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Environment.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Environment.Connection.HostPublicKeys,
|
||||
})
|
||||
}
|
||||
90
internal/codespaces/ssh.go
Normal file
90
internal/codespaces/ssh.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package codespaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Shell runs an interactive secure shell over an existing
|
||||
// port-forwarding session. It runs until the shell is terminated
|
||||
// (including by cancellation of the context).
|
||||
func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error {
|
||||
cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ssh command: %w", err)
|
||||
}
|
||||
|
||||
if usingCustomPort {
|
||||
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
|
||||
}
|
||||
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
|
||||
// command on the remote machine.
|
||||
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) {
|
||||
cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs)
|
||||
return cmd, err
|
||||
}
|
||||
|
||||
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
|
||||
// an interactive shell) over ssh.
|
||||
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) {
|
||||
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
|
||||
|
||||
// The ssh command syntax is: ssh [flags] user@host command [args...]
|
||||
// There is no way to specify the user@host destination as a flag.
|
||||
// Unfortunately, that means we need to know which user-provided words are
|
||||
// SSH flags and which are command arguments so that we can place
|
||||
// them before or after the destination, and that means we need to know all
|
||||
// the flags and their arities.
|
||||
cmdArgs, command, err := parseSSHArgs(cmdArgs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cmdArgs = append(cmdArgs, connArgs...)
|
||||
cmdArgs = append(cmdArgs, "-C") // Compression
|
||||
cmdArgs = append(cmdArgs, dst) // user@host
|
||||
|
||||
if command != nil {
|
||||
cmdArgs = append(cmdArgs, command...)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ssh", cmdArgs...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
return cmd, connArgs, nil
|
||||
}
|
||||
|
||||
// parseSSHArgs parses SSH arguments into two distinct slices of flags and command.
|
||||
// It returns an error if a unary flag is provided without an argument.
|
||||
func parseSSHArgs(args []string) (cmdArgs, command []string, err error) {
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
|
||||
// if we've started parsing the command, set it to the rest of the args
|
||||
if !strings.HasPrefix(arg, "-") {
|
||||
command = args[i:]
|
||||
break
|
||||
}
|
||||
|
||||
cmdArgs = append(cmdArgs, arg)
|
||||
if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) {
|
||||
if i++; i == len(args) {
|
||||
return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg)
|
||||
}
|
||||
|
||||
cmdArgs = append(cmdArgs, args[i])
|
||||
}
|
||||
}
|
||||
|
||||
return cmdArgs, command, nil
|
||||
}
|
||||
105
internal/codespaces/ssh_test.go
Normal file
105
internal/codespaces/ssh_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package codespaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseSSHArgs(t *testing.T) {
|
||||
type testCase struct {
|
||||
Args []string
|
||||
ParsedArgs []string
|
||||
Command []string
|
||||
Error string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{}, // empty test case
|
||||
{
|
||||
Args: []string{"-X", "-Y"},
|
||||
ParsedArgs: []string{"-X", "-Y"},
|
||||
Command: nil,
|
||||
},
|
||||
{
|
||||
Args: []string{"-X", "-Y", "-o", "someoption=test"},
|
||||
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
|
||||
Command: nil,
|
||||
},
|
||||
{
|
||||
Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"},
|
||||
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
|
||||
Command: []string{"somecommand"},
|
||||
},
|
||||
{
|
||||
Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"},
|
||||
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
|
||||
Command: []string{"echo", "test"},
|
||||
},
|
||||
{
|
||||
Args: []string{"somecommand"},
|
||||
ParsedArgs: []string{},
|
||||
Command: []string{"somecommand"},
|
||||
},
|
||||
{
|
||||
Args: []string{"echo", "test"},
|
||||
ParsedArgs: []string{},
|
||||
Command: []string{"echo", "test"},
|
||||
},
|
||||
{
|
||||
Args: []string{"-v", "echo", "hello", "world"},
|
||||
ParsedArgs: []string{"-v"},
|
||||
Command: []string{"echo", "hello", "world"},
|
||||
},
|
||||
{
|
||||
Args: []string{"-L", "-l"},
|
||||
ParsedArgs: []string{"-L", "-l"},
|
||||
Command: nil,
|
||||
},
|
||||
{
|
||||
Args: []string{"-v", "echo", "-n", "test"},
|
||||
ParsedArgs: []string{"-v"},
|
||||
Command: []string{"echo", "-n", "test"},
|
||||
},
|
||||
{
|
||||
Args: []string{"-v", "echo", "-b", "test"},
|
||||
ParsedArgs: []string{"-v"},
|
||||
Command: []string{"echo", "-b", "test"},
|
||||
},
|
||||
{
|
||||
Args: []string{"-b"},
|
||||
ParsedArgs: nil,
|
||||
Command: nil,
|
||||
Error: "ssh flag: -b requires an argument",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tcase := range testCases {
|
||||
args, command, err := parseSSHArgs(tcase.Args)
|
||||
if tcase.Error != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error and got nil: %#v", tcase)
|
||||
}
|
||||
|
||||
if err.Error() != tcase.Error {
|
||||
t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v on test case: %#v", err, tcase)
|
||||
continue
|
||||
}
|
||||
|
||||
argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs)
|
||||
if argsStr != parsedArgsStr {
|
||||
t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr)
|
||||
}
|
||||
|
||||
commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command)
|
||||
if commandStr != parsedCommandStr {
|
||||
t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
118
internal/codespaces/states.go
Normal file
118
internal/codespaces/states.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package codespaces
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/api"
|
||||
"github.com/cli/cli/v2/internal/liveshare"
|
||||
)
|
||||
|
||||
// PostCreateStateStatus is a string value representing the different statuses a state can have.
|
||||
type PostCreateStateStatus string
|
||||
|
||||
func (p PostCreateStateStatus) String() string {
|
||||
return strings.Title(string(p))
|
||||
}
|
||||
|
||||
const (
|
||||
PostCreateStateRunning PostCreateStateStatus = "running"
|
||||
PostCreateStateSuccess PostCreateStateStatus = "succeeded"
|
||||
PostCreateStateFailed PostCreateStateStatus = "failed"
|
||||
)
|
||||
|
||||
// PostCreateState is a combination of a state and status value that is captured
|
||||
// during codespace creation.
|
||||
type PostCreateState struct {
|
||||
Name string `json:"name"`
|
||||
Status PostCreateStateStatus `json:"status"`
|
||||
}
|
||||
|
||||
// PollPostCreateStates watches for state changes in a codespace,
|
||||
// and calls the supplied poller for each batch of state changes.
|
||||
// It runs until it encounters an error, including cancellation of the context.
|
||||
func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
|
||||
token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting codespace token: %w", err)
|
||||
}
|
||||
|
||||
session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to Live Share: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := session.Close(); err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, err := net.Listen("tcp", ":0") // arbitrary port
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
log.Println("Fetching SSH Details...")
|
||||
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ssh server details: %w", err)
|
||||
}
|
||||
|
||||
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
}()
|
||||
|
||||
t := time.NewTicker(1 * time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
case err := <-tunnelClosed:
|
||||
return fmt.Errorf("connection failed: %w", err)
|
||||
|
||||
case <-t.C:
|
||||
states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get post create output: %w", err)
|
||||
}
|
||||
|
||||
poller(states)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) {
|
||||
cmd, err := NewRemoteCommand(
|
||||
ctx, tunnelPort, fmt.Sprintf("%s@localhost", user),
|
||||
"cat /workspaces/.codespaces/shared/postCreateOutput.json",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remote command: %w", err)
|
||||
}
|
||||
|
||||
stdout := new(bytes.Buffer)
|
||||
cmd.Stdout = stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("run command: %w", err)
|
||||
}
|
||||
var output struct {
|
||||
Steps []PostCreateState `json:"steps"`
|
||||
}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &output); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal output: %w", err)
|
||||
}
|
||||
|
||||
return output.Steps, nil
|
||||
}
|
||||
150
internal/liveshare/client.go
Normal file
150
internal/liveshare/client.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
// Package liveshare is a Go client library for the Visual Studio Live Share
|
||||
// service, which provides collaborative, distibuted editing and debugging.
|
||||
// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview.
|
||||
//
|
||||
// It provides the ability for a Go program to connect to a Live Share
|
||||
// workspace (Connect), to expose a TCP port on a remote host
|
||||
// (UpdateSharedVisibility), to start an SSH server listening on an
|
||||
// exposed port (StartSSHServer), and to forward connections between
|
||||
// the remote port and a local listening TCP port (ForwardToListener)
|
||||
// or a local Go reader/writer (Forward).
|
||||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// An Options specifies Live Share connection parameters.
|
||||
type Options struct {
|
||||
SessionID string
|
||||
SessionToken string // token for SSH session
|
||||
RelaySAS string
|
||||
RelayEndpoint string
|
||||
HostPublicKeys []string
|
||||
TLSConfig *tls.Config // (optional)
|
||||
}
|
||||
|
||||
// uri returns a websocket URL for the specified options.
|
||||
func (opts *Options) uri(action string) (string, error) {
|
||||
if opts.SessionID == "" {
|
||||
return "", errors.New("SessionID is required")
|
||||
}
|
||||
if opts.RelaySAS == "" {
|
||||
return "", errors.New("RelaySAS is required")
|
||||
}
|
||||
if opts.RelayEndpoint == "" {
|
||||
return "", errors.New("RelayEndpoint is required")
|
||||
}
|
||||
|
||||
sas := url.QueryEscape(opts.RelaySAS)
|
||||
uri := opts.RelayEndpoint
|
||||
uri = strings.Replace(uri, "sb:", "wss:", -1)
|
||||
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
|
||||
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
|
||||
return uri, nil
|
||||
}
|
||||
|
||||
// Connect connects to a Live Share workspace specified by the
|
||||
// options, and returns a session representing the connection.
|
||||
// The caller must call the session's Close method to end the session.
|
||||
func Connect(ctx context.Context, opts Options) (*Session, error) {
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
|
||||
sock := newSocket(uri, opts.TLSConfig)
|
||||
if err := sock.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting websocket: %w", err)
|
||||
}
|
||||
|
||||
if opts.SessionToken == "" {
|
||||
return nil, errors.New("SessionToken is required")
|
||||
}
|
||||
ssh := newSSHSession(opts.SessionToken, opts.HostPublicKeys, sock)
|
||||
if err := ssh.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
|
||||
}
|
||||
|
||||
rpc := newRPCClient(ssh)
|
||||
rpc.connect(ctx)
|
||||
|
||||
args := joinWorkspaceArgs{
|
||||
ID: opts.SessionID,
|
||||
ConnectionMode: "local",
|
||||
JoiningUserSessionToken: opts.SessionToken,
|
||||
ClientCapabilities: clientCapabilities{
|
||||
IsNonInteractive: false,
|
||||
},
|
||||
}
|
||||
var result joinWorkspaceResult
|
||||
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
|
||||
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
|
||||
}
|
||||
|
||||
return &Session{ssh: ssh, rpc: rpc}, nil
|
||||
}
|
||||
|
||||
type clientCapabilities struct {
|
||||
IsNonInteractive bool `json:"isNonInteractive"`
|
||||
}
|
||||
|
||||
type joinWorkspaceArgs struct {
|
||||
ID string `json:"id"`
|
||||
ConnectionMode string `json:"connectionMode"`
|
||||
JoiningUserSessionToken string `json:"joiningUserSessionToken"`
|
||||
ClientCapabilities clientCapabilities `json:"clientCapabilities"`
|
||||
}
|
||||
|
||||
type joinWorkspaceResult struct {
|
||||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
||||
// A channelID is an identifier for an exposed port on a remote
|
||||
// container that may be used to open an SSH channel to it.
|
||||
type channelID struct {
|
||||
name, condition string
|
||||
}
|
||||
|
||||
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
|
||||
type getStreamArgs struct {
|
||||
StreamName string `json:"streamName"`
|
||||
Condition string `json:"condition"`
|
||||
}
|
||||
args := getStreamArgs{
|
||||
StreamName: id.name,
|
||||
Condition: id.condition,
|
||||
}
|
||||
var streamID string
|
||||
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
|
||||
return nil, fmt.Errorf("error getting stream id: %w", err)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
|
||||
defer span.Finish()
|
||||
_ = ctx // ctx is not currently used
|
||||
|
||||
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
requestType := fmt.Sprintf("stream-transport-%s", streamID)
|
||||
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
|
||||
return nil, fmt.Errorf("error sending channel request: %w", err)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
72
internal/liveshare/client_test.go
Normal file
72
internal/liveshare/client_test.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/internal/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
opts := Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
}
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var joinWorkspaceReq joinWorkspaceArgs
|
||||
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling req: %w", err)
|
||||
}
|
||||
if joinWorkspaceReq.ID != opts.SessionID {
|
||||
return nil, errors.New("connection session id does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ConnectionMode != "local" {
|
||||
return nil, errors.New("connection mode is not local")
|
||||
}
|
||||
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
|
||||
return nil, errors.New("connection user token does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
|
||||
return nil, errors.New("non interactive is not false")
|
||||
}
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
|
||||
server, err := livesharetest.NewServer(
|
||||
livesharetest.WithPassword(opts.SessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
livesharetest.WithRelaySAS(opts.RelaySAS),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating Live Share server: %w", err)
|
||||
}
|
||||
defer server.Close()
|
||||
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err := Connect(ctx, opts) // ignore session
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-server.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
56
internal/liveshare/options_test.go
Normal file
56
internal/liveshare/options_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBadOptions(t *testing.T) {
|
||||
goodOptions := Options{
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
RelayEndpoint: "endpoint",
|
||||
}
|
||||
|
||||
opts := goodOptions
|
||||
opts.SessionID = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.SessionToken = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.RelaySAS = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.RelayEndpoint = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = Options{}
|
||||
checkBadOptions(t, opts)
|
||||
}
|
||||
|
||||
func checkBadOptions(t *testing.T, opts Options) {
|
||||
if _, err := Connect(context.Background(), opts); err == nil {
|
||||
t.Errorf("Connect(%+v): no error", opts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsURI(t *testing.T) {
|
||||
opts := Options{
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
RelayEndpoint: "sb://endpoint/.net/liveshare",
|
||||
}
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
|
||||
t.Errorf("uri is not correct, got: '%v'", uri)
|
||||
}
|
||||
}
|
||||
162
internal/liveshare/port_forwarder.go
Normal file
162
internal/liveshare/port_forwarder.go
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
|
||||
// container to a local destination such as a network port or Go reader/writer.
|
||||
type PortForwarder struct {
|
||||
session *Session
|
||||
name string
|
||||
remotePort int
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified
|
||||
// remote port and Live Share session. The name describes the purpose
|
||||
// of the remote port or service.
|
||||
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
session: session,
|
||||
name: name,
|
||||
remotePort: remotePort,
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardToListener forwards traffic between the container's remote
|
||||
// port and a local port, which must already be listening for
|
||||
// connections. (Accepting a listener rather than a port number avoids
|
||||
// races against other processes opening ports, and against a client
|
||||
// connecting to the socket prematurely.)
|
||||
//
|
||||
// ForwardToListener accepts and handles connections on the local port
|
||||
// until it encounters the first error, which may include context
|
||||
// cancellation. Its error result is always non-nil. The caller is
|
||||
// responsible for closing the listening port.
|
||||
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
sendError := func(err error) {
|
||||
// Use non-blocking send, to avoid goroutines getting
|
||||
// stuck in case of concurrent or sequential errors.
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listen.Accept()
|
||||
if err != nil {
|
||||
sendError(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := fwd.handleConnection(ctx, id, conn); err != nil {
|
||||
sendError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// Forward forwards traffic between the container's remote port and
|
||||
// the specified read/write stream. On return, the stream is closed.
|
||||
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- fwd.handleConnection(ctx, id, conn)
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
|
||||
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case err := <-errc:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // canceled
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
||||
defer span.Finish()
|
||||
|
||||
defer safeClose(conn, &err)
|
||||
|
||||
channel, err := fwd.session.openStreamingChannel(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
|
||||
}
|
||||
// Ideally we would call safeClose again, but (*ssh.channel).Close
|
||||
// appears to have a bug that causes it return io.EOF spuriously
|
||||
// if its peer closed first; see github.com/golang/go/issues/38115.
|
||||
defer func() {
|
||||
closeErr := channel.Close()
|
||||
if err == nil && closeErr != io.EOF {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
// bi-directional copy of data.
|
||||
errs := make(chan error, 2)
|
||||
copyConn := func(w io.Writer, r io.Reader) {
|
||||
_, err := io.Copy(w, r)
|
||||
errs <- err
|
||||
}
|
||||
go copyConn(conn, channel)
|
||||
go copyConn(channel, conn)
|
||||
|
||||
// Wait until context is cancelled or both copies are done.
|
||||
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
|
||||
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
|
||||
for i := 0; ; {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-errs:
|
||||
i++
|
||||
if i == 2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safeClose reports the error (to *err) from closing the stream only
|
||||
// if no other error was previously reported.
|
||||
func safeClose(closer io.Closer, err *error) {
|
||||
closeErr := closer.Close()
|
||||
if *err == nil {
|
||||
*err = closeErr
|
||||
}
|
||||
}
|
||||
95
internal/liveshare/port_forwarder_test.go
Normal file
95
internal/liveshare/port_forwarder_test.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/internal/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
testServer, session, err := makeMockSession()
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
pf := NewPortForwarder(session, "ssh", 80)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderStart(t *testing.T) {
|
||||
streamName, streamCondition := "stream-name", "stream-condition"
|
||||
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
||||
}
|
||||
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return "stream-id", nil
|
||||
}
|
||||
|
||||
stream := bytes.NewBufferString("stream-data")
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.startSharing", serverSharing),
|
||||
livesharetest.WithService("streamManager.getStream", getStream),
|
||||
livesharetest.WithStream("stream-id", stream),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("create mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
listen, err := net.Listen("tcp", ":8000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listen.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
const name, remote = "ssh", 8000
|
||||
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var conn net.Conn
|
||||
retries := 0
|
||||
for conn == nil && retries < 2 {
|
||||
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
if conn == nil {
|
||||
done <- errors.New("failed to connect to forwarded port")
|
||||
}
|
||||
b := make([]byte, len("stream-data"))
|
||||
if _, err := conn.Read(b); err != nil && err != io.EOF {
|
||||
done <- fmt.Errorf("reading stream: %w", err)
|
||||
}
|
||||
if string(b) != "stream-data" {
|
||||
done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
|
||||
}
|
||||
if _, err := conn.Write([]byte("new-data")); err != nil {
|
||||
done <- fmt.Errorf("writing to stream: %w", err)
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
41
internal/liveshare/rpc.go
Normal file
41
internal/liveshare/rpc.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
type rpcClient struct {
|
||||
*jsonrpc2.Conn
|
||||
conn io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
|
||||
return &rpcClient{conn: conn}
|
||||
}
|
||||
|
||||
func (r *rpcClient) connect(ctx context.Context) {
|
||||
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
|
||||
r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{})
|
||||
}
|
||||
|
||||
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, method)
|
||||
defer span.Finish()
|
||||
|
||||
waiter, err := r.Conn.DispatchCall(ctx, method, args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error dispatching %q call: %w", method, err)
|
||||
}
|
||||
|
||||
return waiter.Wait(ctx, result)
|
||||
}
|
||||
|
||||
type nullHandler struct{}
|
||||
|
||||
func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
}
|
||||
99
internal/liveshare/session.go
Normal file
99
internal/liveshare/session.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// A Session represents the session between a connected Live Share client and server.
|
||||
type Session struct {
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
}
|
||||
|
||||
// Close should be called by users to clean up RPC and SSH resources whenever the session
|
||||
// is no longer active.
|
||||
func (s *Session) Close() error {
|
||||
// Closing the RPC conn closes the underlying stream (SSH)
|
||||
// So we only need to close once
|
||||
if err := s.rpc.Close(); err != nil {
|
||||
s.ssh.Close() // close SSH and ignore error
|
||||
return fmt.Errorf("error while closing Live Share session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Port describes a port exposed by the container.
|
||||
type Port struct {
|
||||
SourcePort int `json:"sourcePort"`
|
||||
DestinationPort int `json:"destinationPort"`
|
||||
SessionName string `json:"sessionName"`
|
||||
StreamName string `json:"streamName"`
|
||||
StreamCondition string `json:"streamCondition"`
|
||||
BrowseURL string `json:"browseUrl"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"`
|
||||
HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"`
|
||||
}
|
||||
|
||||
// startSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the remote port or service.
|
||||
// It returns an identifier that can be used to open an SSH channel to the remote port.
|
||||
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
|
||||
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
|
||||
var response Port
|
||||
if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil {
|
||||
return channelID{}, err
|
||||
}
|
||||
|
||||
return channelID{response.StreamName, response.StreamCondition}, nil
|
||||
}
|
||||
|
||||
// GetSharedServers returns a description of each container port
|
||||
// shared by a prior call to StartSharing by some client.
|
||||
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
|
||||
var response []*Port
|
||||
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
|
||||
// via the Browse URL
|
||||
func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
|
||||
if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartsSSHServer starts an SSH server in the container, installing sshd if necessary,
|
||||
// and returns the port on which it listens and the user name clients should provide.
|
||||
func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
|
||||
var response struct {
|
||||
Result bool `json:"result"`
|
||||
ServerPort string `json:"serverPort"`
|
||||
User string `json:"user"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
if !response.Result {
|
||||
return 0, "", fmt.Errorf("failed to start server: %s", response.Message)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(response.ServerPort)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("failed to parse port: %w", err)
|
||||
}
|
||||
|
||||
return port, response.User, nil
|
||||
}
|
||||
223
internal/liveshare/session_test.go
Normal file
223
internal/liveshare/session_test.go
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/internal/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
opts = append(
|
||||
opts,
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
)
|
||||
testServer, err := livesharetest.NewServer(opts...)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating server: %w", err)
|
||||
}
|
||||
|
||||
session, err := Connect(context.Background(), Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
return testServer, session, nil
|
||||
}
|
||||
|
||||
func TestServerStartSharing(t *testing.T) {
|
||||
serverPort, serverProtocol := 2222, "sshd"
|
||||
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var args []interface{}
|
||||
if err := json.Unmarshal(*req.Params, &args); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling request: %w", err)
|
||||
}
|
||||
if len(args) < 3 {
|
||||
return nil, errors.New("not enough arguments to start sharing")
|
||||
}
|
||||
if port, ok := args[0].(float64); !ok {
|
||||
return nil, errors.New("port argument is not an int")
|
||||
} else if port != float64(serverPort) {
|
||||
return nil, errors.New("port does not match serverPort")
|
||||
}
|
||||
if protocol, ok := args[1].(string); !ok {
|
||||
return nil, errors.New("protocol argument is not a string")
|
||||
} else if protocol != serverProtocol {
|
||||
return nil, errors.New("protocol does not match serverProtocol")
|
||||
}
|
||||
if browseURL, ok := args[2].(string); !ok {
|
||||
return nil, errors.New("browse url is not a string")
|
||||
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
|
||||
return nil, errors.New("browseURL does not match expected")
|
||||
}
|
||||
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.startSharing", startSharing),
|
||||
)
|
||||
defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock session: %w", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error sharing server: %w", err)
|
||||
}
|
||||
if streamID.name == "" || streamID.condition == "" {
|
||||
done <- errors.New("stream name or condition is blank")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerGetSharedServers(t *testing.T) {
|
||||
sharedServer := Port{
|
||||
SourcePort: 2222,
|
||||
StreamName: "stream-name",
|
||||
StreamCondition: "stream-condition",
|
||||
}
|
||||
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return []*Port{&sharedServer}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
ports, err := session.GetSharedServers(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error getting shared servers: %w", err)
|
||||
}
|
||||
if len(ports) < 1 {
|
||||
done <- errors.New("not enough ports returned")
|
||||
}
|
||||
if ports[0].SourcePort != sharedServer.SourcePort {
|
||||
done <- errors.New("source port does not match")
|
||||
}
|
||||
if ports[0].StreamName != sharedServer.StreamName {
|
||||
done <- errors.New("stream name does not match")
|
||||
}
|
||||
if ports[0].StreamCondition != sharedServer.StreamCondition {
|
||||
done <- errors.New("stream condiion does not match")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerUpdateSharedVisibility(t *testing.T) {
|
||||
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
if port, ok := req[0].(float64); ok {
|
||||
if port != 80.0 {
|
||||
return nil, errors.New("port param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("port param is not a float64")
|
||||
}
|
||||
if public, ok := req[1].(bool); ok {
|
||||
if public != true {
|
||||
return nil, errors.New("pulic param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("public param is not a bool")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- session.UpdateSharedVisibility(ctx, 80, true)
|
||||
}()
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidHostKey(t *testing.T) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
opts := []livesharetest.ServerOption{
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
}
|
||||
testServer, err := livesharetest.NewServer(opts...)
|
||||
if err != nil {
|
||||
t.Errorf("error creating server: %w", err)
|
||||
}
|
||||
_, err = Connect(context.Background(), Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{},
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected invalid host key error, got: nil")
|
||||
}
|
||||
}
|
||||
100
internal/liveshare/socket.go
Normal file
100
internal/liveshare/socket.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socket struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newSocket(uri string, tlsConfig *tls.Config) *socket {
|
||||
return &socket{addr: uri, tlsConfig: tlsConfig}
|
||||
}
|
||||
|
||||
func (s *socket) connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
TLSClientConfig: s.tlsConfig,
|
||||
}
|
||||
ws, _, err := dialer.Dial(s.addr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn = ws
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *socket) Read(b []byte) (int, error) {
|
||||
if s.reader == nil {
|
||||
_, reader, err := s.conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.reader = reader
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socket) Write(b []byte) (int, error) {
|
||||
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bytesWritten, err := nextWriter.Write(b)
|
||||
nextWriter.Close()
|
||||
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
func (s *socket) Close() error {
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *socket) LocalAddr() net.Addr {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (s *socket) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *socket) SetDeadline(t time.Time) error {
|
||||
if err := s.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetReadDeadline(t time.Time) error {
|
||||
return s.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetWriteDeadline(t time.Time) error {
|
||||
return s.conn.SetWriteDeadline(t)
|
||||
}
|
||||
79
internal/liveshare/ssh.go
Normal file
79
internal/liveshare/ssh.go
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type sshSession struct {
|
||||
*ssh.Session
|
||||
token string
|
||||
hostPublicKeys []string
|
||||
socket net.Conn
|
||||
conn ssh.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func newSSHSession(token string, hostPublicKeys []string, socket net.Conn) *sshSession {
|
||||
return &sshSession{token: token, hostPublicKeys: hostPublicKeys, socket: socket}
|
||||
}
|
||||
|
||||
func (s *sshSession) connect(ctx context.Context) error {
|
||||
clientConfig := ssh.ClientConfig{
|
||||
User: "",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(s.token),
|
||||
},
|
||||
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
|
||||
HostKeyCallback: func(hostname string, addr net.Addr, key ssh.PublicKey) error {
|
||||
encodedKey := base64.StdEncoding.EncodeToString(key.Marshal())
|
||||
for _, hpk := range s.hostPublicKeys {
|
||||
if encodedKey == hpk {
|
||||
return nil // we found a match for expected public key, safely return
|
||||
}
|
||||
}
|
||||
return errors.New("invalid host public key")
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh client connection: %w", err)
|
||||
}
|
||||
s.conn = sshClientConn
|
||||
|
||||
sshClient := ssh.NewClient(sshClientConn, chans, reqs)
|
||||
s.Session, err = sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh client session: %w", err)
|
||||
}
|
||||
|
||||
s.reader, err = s.Session.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh session reader: %w", err)
|
||||
}
|
||||
|
||||
s.writer, err = s.Session.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh session writer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sshSession) Read(p []byte) (n int, err error) {
|
||||
return s.reader.Read(p)
|
||||
}
|
||||
|
||||
func (s *sshSession) Write(p []byte) (n int, err error) {
|
||||
return s.writer.Write(p)
|
||||
}
|
||||
334
internal/liveshare/test/server.go
Normal file
334
internal/liveshare/test/server.go
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV
|
||||
rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY
|
||||
lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB
|
||||
AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0
|
||||
5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8
|
||||
IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC
|
||||
46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n
|
||||
IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC
|
||||
t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz
|
||||
J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA
|
||||
hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV
|
||||
933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP
|
||||
FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw==
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==`
|
||||
|
||||
// Server represents a LiveShare relay host server.
|
||||
type Server struct {
|
||||
password string
|
||||
services map[string]RPCHandleFunc
|
||||
relaySAS string
|
||||
streams map[string]io.ReadWriter
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
// NewServer creates a new Server. ServerOptions can be passed to configure
|
||||
// the SSH password, backing service, secrets and more.
|
||||
func NewServer(opts ...ServerOption) (*Server, error) {
|
||||
server := new(Server)
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
server.sshConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: sshPasswordCallback(server.password),
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing key: %w", err)
|
||||
}
|
||||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error, 1)
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// ServerOption is used to configure the Server.
|
||||
type ServerOption func(*Server) error
|
||||
|
||||
// WithPassword configures the Server password for SSH.
|
||||
func WithPassword(password string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.password = password
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithService accepts a mock RPC service for the Server to invoke.
|
||||
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.services == nil {
|
||||
s.services = make(map[string]RPCHandleFunc)
|
||||
}
|
||||
|
||||
s.services[serviceName] = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRelaySAS configures the relay SAS configuration key.
|
||||
func WithRelaySAS(sas string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.relaySAS = sas
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithStream allows you to specify a mock data stream for the server.
|
||||
func WithStream(name string, stream io.ReadWriter) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.streams == nil {
|
||||
s.streams = make(map[string]io.ReadWriter)
|
||||
}
|
||||
s.streams[name] = stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("password rejected")
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying httptest Server.
|
||||
func (s *Server) Close() {
|
||||
s.httptestServer.Close()
|
||||
}
|
||||
|
||||
// URL returns the httptest Server url.
|
||||
func (s *Server) URL() string {
|
||||
return s.httptestServer.URL
|
||||
}
|
||||
|
||||
func (s *Server) Err() <-chan error {
|
||||
return s.errCh
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if server.relaySAS != "" {
|
||||
// validate the sas key
|
||||
sasParam := req.URL.Query().Get("sb-hc-token")
|
||||
if sasParam != server.relaySAS {
|
||||
sendError(server.errCh, errors.New("error validating sas"))
|
||||
return
|
||||
}
|
||||
}
|
||||
c, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
sendError(server.errCh, err)
|
||||
}
|
||||
}()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err))
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
if err := handleChannels(ctx, server, chans); err != nil {
|
||||
sendError(server.errCh, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendError does a non-blocking send of the error to the err channel.
|
||||
func sendError(errc chan<- error, err error) {
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
// channel is blocked with a previous error, so we ignore
|
||||
// this current error
|
||||
}
|
||||
}
|
||||
|
||||
// awaitError waits for the context to finish and returns its error (if any).
|
||||
// It also waits for an err to come through the err channel.
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// handleChannels services the sshChannels channel. For each SSH channel received
|
||||
// it creates a go routine to service the channel's requests. It returns on the first
|
||||
// error encountered.
|
||||
func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
for sshCh := range sshChannels {
|
||||
ch, reqs, err := sshCh.Accept()
|
||||
if err != nil {
|
||||
sendError(errc, fmt.Errorf("failed to accept channel: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := handleRequests(ctx, server, ch, reqs); err != nil {
|
||||
sendError(errc, fmt.Errorf("failed to handle requests: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
handleChannel(server, ch)
|
||||
}
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// handleRequests services the SSH channel requests channel. It replies to requests and
|
||||
// when stream transport requests are encountered, creates a go routine to create a
|
||||
// bi-directional data stream between the channel and server stream. It returns on the first error
|
||||
// encountered.
|
||||
func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
if req.WantReply {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
sendError(errc, fmt.Errorf("error replying to channel request: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(req.Type, "stream-transport") {
|
||||
go func() {
|
||||
if err := forwardStream(ctx, server, req.Type, channel); err != nil {
|
||||
sendError(errc, fmt.Errorf("failed to forward stream: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// concurrentStream is a concurrency safe io.ReadWriter.
|
||||
type concurrentStream struct {
|
||||
sync.RWMutex
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
func newConcurrentStream(rw io.ReadWriter) *concurrentStream {
|
||||
return &concurrentStream{stream: rw}
|
||||
}
|
||||
|
||||
func (cs *concurrentStream) Read(b []byte) (int, error) {
|
||||
cs.RLock()
|
||||
defer cs.RUnlock()
|
||||
return cs.stream.Read(b)
|
||||
}
|
||||
|
||||
func (cs *concurrentStream) Write(b []byte) (int, error) {
|
||||
cs.Lock()
|
||||
defer cs.Unlock()
|
||||
return cs.stream.Write(b)
|
||||
}
|
||||
|
||||
// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
|
||||
// runs until an error is encountered.
|
||||
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) {
|
||||
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
|
||||
stream, found := server.streams[simpleStreamName]
|
||||
if !found {
|
||||
return fmt.Errorf("stream '%s' not found", simpleStreamName)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
errc := make(chan error, 2)
|
||||
copy := func(dst io.Writer, src io.Reader) {
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
errc <- err
|
||||
}
|
||||
}
|
||||
|
||||
csStream := newConcurrentStream(stream)
|
||||
go copy(csStream, channel)
|
||||
go copy(channel, csStream)
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func handleChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
|
||||
}
|
||||
|
||||
type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
|
||||
|
||||
type rpcHandler struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
func newRPCHandler(server *Server) *rpcHandler {
|
||||
return &rpcHandler{server}
|
||||
}
|
||||
|
||||
// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
|
||||
// RPC service method and if found, it invokes the handler and replies to the request.
|
||||
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
handler, found := r.server.services[req.Method]
|
||||
if !found {
|
||||
sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method))
|
||||
return
|
||||
}
|
||||
|
||||
result, err := handler(req)
|
||||
if err != nil {
|
||||
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Reply(ctx, req.ID, result); err != nil {
|
||||
sendError(r.server.errCh, fmt.Errorf("error replying: %w", err))
|
||||
}
|
||||
}
|
||||
77
internal/liveshare/test/socket.go
Normal file
77
internal/liveshare/test/socket.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %w", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %w", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %w", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
|
@ -2,8 +2,12 @@ package root
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/v2/cmd/ghcs"
|
||||
"github.com/cli/cli/v2/cmd/ghcs/output"
|
||||
ghcsApi "github.com/cli/cli/v2/internal/api"
|
||||
actionsCmd "github.com/cli/cli/v2/pkg/cmd/actions"
|
||||
aliasCmd "github.com/cli/cli/v2/pkg/cmd/alias"
|
||||
apiCmd "github.com/cli/cli/v2/pkg/cmd/api"
|
||||
|
|
@ -78,6 +82,7 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) *cobra.Command {
|
|||
cmd.AddCommand(extensionCmd.NewCmdExtension(f))
|
||||
cmd.AddCommand(secretCmd.NewCmdSecret(f))
|
||||
cmd.AddCommand(sshKeyCmd.NewCmdSSHKey(f))
|
||||
cmd.AddCommand(newCodespaceCmd(f))
|
||||
|
||||
// the `api` command should not inherit any extra HTTP headers
|
||||
bareHTTPCmdFactory := *f
|
||||
|
|
@ -121,3 +126,39 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er
|
|||
return factory.NewHTTPClient(f.IOStreams, cfg, version, false)
|
||||
}
|
||||
}
|
||||
|
||||
func newCodespaceCmd(f *cmdutil.Factory) *cobra.Command {
|
||||
cmd := ghcs.NewRootCmd(ghcs.NewApp(
|
||||
output.NewLogger(f.IOStreams.Out, f.IOStreams.ErrOut, !f.IOStreams.IsStdoutTTY()),
|
||||
ghcsApi.New("", &lazyLoadedHTTPClient{factory: f}),
|
||||
))
|
||||
cmd.Use = "codespace"
|
||||
cmd.Aliases = []string{"cs"}
|
||||
cmd.Hidden = true
|
||||
return cmd
|
||||
}
|
||||
|
||||
type lazyLoadedHTTPClient struct {
|
||||
factory *cmdutil.Factory
|
||||
|
||||
httpClientMu sync.RWMutex // guards httpClient
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func (l *lazyLoadedHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
l.httpClientMu.RLock()
|
||||
httpClient := l.httpClient
|
||||
l.httpClientMu.RUnlock()
|
||||
|
||||
if httpClient == nil {
|
||||
l.httpClientMu.Lock()
|
||||
defer l.httpClientMu.Unlock()
|
||||
|
||||
var err error
|
||||
l.httpClient, err = l.factory.HttpClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return l.httpClient.Do(req)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue