Merge remote-tracking branch 'ghcs/main' into import-codespaces

Co-authored-by: Jose Garcia <josebalius@github.com>
This commit is contained in:
Mislav Marohnić 2021-09-28 16:46:27 +02:00
commit e64607d07b
104 changed files with 26752 additions and 0 deletions

65
cmd/ghcs/code.go Normal file
View file

@ -0,0 +1,65 @@
package ghcs
import (
"context"
"fmt"
"net/url"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
)
func newCodeCmd(app *App) *cobra.Command {
var (
codespace string
useInsiders bool
)
codeCmd := &cobra.Command{
Use: "code",
Short: "Open a codespace in VS Code",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.VSCode(cmd.Context(), codespace, useInsiders)
},
}
codeCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of VS Code")
return codeCmd
}
// VSCode opens a codespace in the local VS VSCode application.
func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
if codespaceName == "" {
codespace, err := chooseCodespace(ctx, a.apiClient, user)
if err != nil {
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error choosing codespace: %w", err)
}
codespaceName = codespace.Name
}
url := vscodeProtocolURL(codespaceName, useInsiders)
if err := open.Run(url); err != nil {
return fmt.Errorf("error opening vscode URL %s: %s. (Is VS Code installed?)", url, err)
}
return nil
}
func vscodeProtocolURL(codespaceName string, useInsiders bool) string {
application := "vscode"
if useInsiders {
application = "vscode-insiders"
}
return fmt.Sprintf("%s://github.codespaces/connect?name=%s", application, url.QueryEscape(codespaceName))
}

185
cmd/ghcs/common.go Normal file
View file

@ -0,0 +1,185 @@
package ghcs
// This file defines functions common to the entire ghcs command set.
import (
"context"
"errors"
"fmt"
"io"
"os"
"sort"
"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
"golang.org/x/term"
)
type App struct {
apiClient apiClient
logger *output.Logger
}
func NewApp(logger *output.Logger, apiClient apiClient) *App {
return &App{
apiClient: apiClient,
logger: logger,
}
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
type apiClient interface {
GetUser(ctx context.Context) (*api.User, error)
GetCodespaceToken(ctx context.Context, user, name string) (string, error)
GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error)
ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error)
DeleteCodespace(ctx context.Context, user, name string) error
StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
AuthorizedKeys(ctx context.Context, user string) ([]byte, error)
GetCodespaceRegionLocation(ctx context.Context) (string, error)
GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error)
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
}
var errNoCodespaces = errors.New("you have no codespaces")
func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) {
codespaces, err := apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return nil, fmt.Errorf("error getting codespaces: %w", err)
}
return chooseCodespaceFromList(ctx, codespaces)
}
func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) (*api.Codespace, error) {
if len(codespaces) == 0 {
return nil, errNoCodespaces
}
sort.Slice(codespaces, func(i, j int) bool {
return codespaces[i].CreatedAt > codespaces[j].CreatedAt
})
codespacesByName := make(map[string]*api.Codespace)
codespacesNames := make([]string, 0, len(codespaces))
for _, codespace := range codespaces {
codespacesByName[codespace.Name] = codespace
codespacesNames = append(codespacesNames, codespace.Name)
}
sshSurvey := []*survey.Question{
{
Name: "codespace",
Prompt: &survey.Select{
Message: "Choose codespace:",
Options: codespacesNames,
Default: codespacesNames[0],
},
Validate: survey.Required,
},
}
var answers struct {
Codespace string
}
if err := ask(sshSurvey, &answers); err != nil {
return nil, fmt.Errorf("error getting answers: %w", err)
}
codespace := codespacesByName[answers.Codespace]
return codespace, nil
}
// getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty.
// It then fetches the codespace token and the codespace record.
func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) {
if codespaceName == "" {
codespace, err = chooseCodespace(ctx, apiClient, user)
if err != nil {
if err == errNoCodespaces {
return nil, "", err
}
return nil, "", fmt.Errorf("choosing codespace: %w", err)
}
codespaceName = codespace.Name
token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName)
if err != nil {
return nil, "", fmt.Errorf("getting codespace token: %w", err)
}
} else {
token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName)
if err != nil {
return nil, "", fmt.Errorf("getting codespace token for given codespace: %w", err)
}
codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName)
if err != nil {
return nil, "", fmt.Errorf("getting full codespace details: %w", err)
}
}
return codespace, token, nil
}
func safeClose(closer io.Closer, err *error) {
if closeErr := closer.Close(); *err == nil {
*err = closeErr
}
}
// hasTTY indicates whether the process connected to a terminal.
// It is not portable to assume stdin/stdout are fds 0 and 1.
var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
// ask asks survey questions on the terminal, using standard options.
// It fails unless hasTTY, but ideally callers should avoid calling it in that case.
func ask(qs []*survey.Question, response interface{}) error {
if !hasTTY {
return fmt.Errorf("no terminal")
}
err := survey.Ask(qs, response, survey.WithShowCursor(true))
// The survey package temporarily clears the terminal's ISIG mode bit
// (see tcsetattr(3)) so the QUIT button (Ctrl-C) is reported as
// ASCII \x03 (ETX) instead of delivering SIGINT to the application.
// So we have to serve ourselves the SIGINT.
//
// https://github.com/AlecAivazis/survey/#why-isnt-ctrl-c-working
if err == terminal.InterruptErr {
self, _ := os.FindProcess(os.Getpid())
_ = self.Signal(os.Interrupt) // assumes POSIX
// Suspend the goroutine, to avoid a race between
// return from main and async delivery of INT signal.
select {}
}
return err
}
// checkAuthorizedKeys reports an error if the user has not registered any SSH keys;
// see https://github.com/github/ghcs/issues/166#issuecomment-921769703.
// The check is not required for security but it improves the error message.
func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error {
keys, err := client.AuthorizedKeys(ctx, user)
if err != nil {
return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err)
}
if len(keys) == 0 {
return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user)
}
return nil // success
}
var ErrTooManyArgs = errors.New("the command accepts no arguments")
func noArgsConstraint(cmd *cobra.Command, args []string) error {
if len(args) > 0 {
return ErrTooManyArgs
}
return nil
}

297
cmd/ghcs/create.go Normal file
View file

@ -0,0 +1,297 @@
package ghcs
import (
"context"
"errors"
"fmt"
"os"
"strings"
"github.com/AlecAivazis/survey/v2"
"github.com/fatih/camelcase"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/spf13/cobra"
)
type createOptions struct {
repo string
branch string
machine string
showStatus bool
}
func newCreateCmd(app *App) *cobra.Command {
opts := createOptions{}
createCmd := &cobra.Command{
Use: "create",
Short: "Create a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.Create(cmd.Context(), opts)
},
}
createCmd.Flags().StringVarP(&opts.repo, "repo", "r", "", "repository name with owner: user/repo")
createCmd.Flags().StringVarP(&opts.branch, "branch", "b", "", "repository branch")
createCmd.Flags().StringVarP(&opts.machine, "machine", "m", "", "hardware specifications for the VM")
createCmd.Flags().BoolVarP(&opts.showStatus, "status", "s", false, "show status of post-create command and dotfiles")
return createCmd
}
// Create creates a new Codespace
func (a *App) Create(ctx context.Context, opts createOptions) error {
locationCh := getLocation(ctx, a.apiClient)
userCh := getUser(ctx, a.apiClient)
repo, err := getRepoName(opts.repo)
if err != nil {
return fmt.Errorf("error getting repository name: %w", err)
}
branch, err := getBranchName(opts.branch)
if err != nil {
return fmt.Errorf("error getting branch name: %w", err)
}
repository, err := a.apiClient.GetRepository(ctx, repo)
if err != nil {
return fmt.Errorf("error getting repository: %w", err)
}
locationResult := <-locationCh
if locationResult.Err != nil {
return fmt.Errorf("error getting codespace region location: %w", locationResult.Err)
}
userResult := <-userCh
if userResult.Err != nil {
return fmt.Errorf("error getting codespace user: %w", userResult.Err)
}
machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient)
if err != nil {
return fmt.Errorf("error getting machine type: %w", err)
}
if machine == "" {
return errors.New("there are no available machine types for this repository")
}
a.logger.Print("Creating your codespace...")
codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{
User: userResult.User.Login,
RepositoryID: repository.ID,
Branch: branch,
Machine: machine,
Location: locationResult.Location,
})
a.logger.Print("\n")
if err != nil {
return fmt.Errorf("error creating codespace: %w", err)
}
if opts.showStatus {
if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil {
return fmt.Errorf("show status: %w", err)
}
}
a.logger.Printf("Codespace created: ")
fmt.Fprintln(os.Stdout, codespace.Name)
return nil
}
// showStatus polls the codespace for a list of post create states and their status. It will keep polling
// until all states have finished. Once all states have finished, we poll once more to check if any new
// states have been introduced and stop polling otherwise.
func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error {
var lastState codespaces.PostCreateState
var breakNextState bool
finishedStates := make(map[string]bool)
ctx, stopPolling := context.WithCancel(ctx)
defer stopPolling()
poller := func(states []codespaces.PostCreateState) {
var inProgress bool
for _, state := range states {
if _, found := finishedStates[state.Name]; found {
continue // skip this state as we've processed it already
}
if state.Name != lastState.Name {
log.Print(state.Name)
if state.Status == codespaces.PostCreateStateRunning {
inProgress = true
lastState = state
log.Print("...")
break
}
finishedStates[state.Name] = true
log.Println("..." + state.Status)
} else {
if state.Status == codespaces.PostCreateStateRunning {
inProgress = true
log.Print(".")
break
}
finishedStates[state.Name] = true
log.Println(state.Status)
lastState = codespaces.PostCreateState{} // reset the value
}
}
if !inProgress {
if breakNextState {
stopPolling()
return
}
breakNextState = true
}
}
err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller)
if err != nil {
if errors.Is(err, context.Canceled) && breakNextState {
return nil // we cancelled the context to stop polling, we can ignore the error
}
return fmt.Errorf("failed to poll state changes from codespace: %w", err)
}
return nil
}
type getUserResult struct {
User *api.User
Err error
}
// getUser fetches the user record associated with the GITHUB_TOKEN
func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult {
ch := make(chan getUserResult, 1)
go func() {
user, err := apiClient.GetUser(ctx)
ch <- getUserResult{user, err}
}()
return ch
}
type locationResult struct {
Location string
Err error
}
// getLocation fetches the closest Codespace datacenter region/location to the user.
func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult {
ch := make(chan locationResult, 1)
go func() {
location, err := apiClient.GetCodespaceRegionLocation(ctx)
ch <- locationResult{location, err}
}()
return ch
}
// getRepoName prompts the user for the name of the repository, or returns the repository if non-empty.
func getRepoName(repo string) (string, error) {
if repo != "" {
return repo, nil
}
repoSurvey := []*survey.Question{
{
Name: "repository",
Prompt: &survey.Input{Message: "Repository:"},
Validate: survey.Required,
},
}
err := ask(repoSurvey, &repo)
return repo, err
}
// getBranchName prompts the user for the name of the branch, or returns the branch if non-empty.
func getBranchName(branch string) (string, error) {
if branch != "" {
return branch, nil
}
branchSurvey := []*survey.Question{
{
Name: "branch",
Prompt: &survey.Input{Message: "Branch:"},
Validate: survey.Required,
},
}
err := ask(branchSurvey, &branch)
return branch, err
}
// getMachineName prompts the user to select the machine type, or validates the machine if non-empty.
func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) {
skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location)
if err != nil {
return "", fmt.Errorf("error requesting machine instance types: %w", err)
}
// if user supplied a machine type, it must be valid
// if no machine type was supplied, we don't error if there are no machine types for the current repo
if machine != "" {
for _, sku := range skus {
if machine == sku.Name {
return machine, nil
}
}
availableSKUs := make([]string, len(skus))
for i := 0; i < len(skus); i++ {
availableSKUs[i] = skus[i].Name
}
return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSKUs)
} else if len(skus) == 0 {
return "", nil
}
if len(skus) == 1 {
return skus[0].Name, nil // VS Code does not prompt for SKU if there is only one, this makes us consistent with that behavior
}
skuNames := make([]string, 0, len(skus))
skuByName := make(map[string]*api.SKU)
for _, sku := range skus {
nameParts := camelcase.Split(sku.Name)
machineName := strings.Title(strings.ToLower(nameParts[0]))
skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName)
skuNames = append(skuNames, skuName)
skuByName[skuName] = sku
}
skuSurvey := []*survey.Question{
{
Name: "sku",
Prompt: &survey.Select{
Message: "Choose Machine Type:",
Options: skuNames,
Default: skuNames[0],
},
Validate: survey.Required,
},
}
var skuAnswers struct{ SKU string }
if err := ask(skuSurvey, &skuAnswers); err != nil {
return "", fmt.Errorf("error getting SKU: %w", err)
}
sku := skuByName[skuAnswers.SKU]
machine = sku.Name
return machine, nil
}

180
cmd/ghcs/delete.go Normal file
View file

@ -0,0 +1,180 @@
package ghcs
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/AlecAivazis/survey/v2"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
type deleteOptions struct {
deleteAll bool
skipConfirm bool
codespaceName string
repoFilter string
keepDays uint16
isInteractive bool
now func() time.Time
prompter prompter
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_prompter.go . prompter
type prompter interface {
Confirm(message string) (bool, error)
}
func newDeleteCmd(app *App) *cobra.Command {
opts := deleteOptions{
isInteractive: hasTTY,
now: time.Now,
prompter: &surveyPrompter{},
}
deleteCmd := &cobra.Command{
Use: "delete",
Short: "Delete a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
if opts.deleteAll && opts.repoFilter != "" {
return errors.New("both --all and --repo is not supported")
}
return app.Delete(cmd.Context(), opts)
},
}
deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace")
deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces")
deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a `repository`")
deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes")
deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days")
return deleteCmd
}
func (a *App) Delete(ctx context.Context, opts deleteOptions) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
var codespaces []*api.Codespace
nameFilter := opts.codespaceName
if nameFilter == "" {
codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return fmt.Errorf("error getting codespaces: %w", err)
}
if !opts.deleteAll && opts.repoFilter == "" {
c, err := chooseCodespaceFromList(ctx, codespaces)
if err != nil {
return fmt.Errorf("error choosing codespace: %w", err)
}
nameFilter = c.Name
}
} else {
// TODO: this token is discarded and then re-requested later in DeleteCodespace
token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter)
if err != nil {
return fmt.Errorf("error getting codespace token: %w", err)
}
codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter)
if err != nil {
return fmt.Errorf("error fetching codespace information: %w", err)
}
codespaces = []*api.Codespace{codespace}
}
codespacesToDelete := make([]*api.Codespace, 0, len(codespaces))
lastUpdatedCutoffTime := opts.now().AddDate(0, 0, -int(opts.keepDays))
for _, c := range codespaces {
if nameFilter != "" && c.Name != nameFilter {
continue
}
if opts.repoFilter != "" && !strings.EqualFold(c.RepositoryNWO, opts.repoFilter) {
continue
}
if opts.keepDays > 0 {
t, err := time.Parse(time.RFC3339, c.LastUsedAt)
if err != nil {
return fmt.Errorf("error parsing last_used_at timestamp %q: %w", c.LastUsedAt, err)
}
if t.After(lastUpdatedCutoffTime) {
continue
}
}
if !opts.skipConfirm {
confirmed, err := confirmDeletion(opts.prompter, c, opts.isInteractive)
if err != nil {
return fmt.Errorf("unable to confirm: %w", err)
}
if !confirmed {
continue
}
}
codespacesToDelete = append(codespacesToDelete, c)
}
if len(codespacesToDelete) == 0 {
return errors.New("no codespaces to delete")
}
g := errgroup.Group{}
for _, c := range codespacesToDelete {
codespaceName := c.Name
g.Go(func() error {
if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil {
_, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err)
return err
}
return nil
})
}
if err := g.Wait(); err != nil {
return errors.New("some codespaces failed to delete")
}
return nil
}
func confirmDeletion(p prompter, codespace *api.Codespace, isInteractive bool) (bool, error) {
gs := codespace.Environment.GitStatus
hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges
if !hasUnsavedChanges {
return true, nil
}
if !isInteractive {
return false, fmt.Errorf("codespace %s has unsaved changes (use --force to override)", codespace.Name)
}
return p.Confirm(fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name))
}
type surveyPrompter struct{}
func (p *surveyPrompter) Confirm(message string) (bool, error) {
var confirmed struct {
Confirmed bool
}
q := []*survey.Question{
{
Name: "confirmed",
Prompt: &survey.Confirm{
Message: message,
},
},
}
if err := ask(q, &confirmed); err != nil {
return false, fmt.Errorf("failed to prompt: %w", err)
}
return confirmed.Confirmed, nil
}

241
cmd/ghcs/delete_test.go Normal file
View file

@ -0,0 +1,241 @@
package ghcs
import (
"bytes"
"context"
"errors"
"fmt"
"sort"
"testing"
"time"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
)
func TestDelete(t *testing.T) {
user := &api.User{Login: "hubot"}
now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z")
daysAgo := func(n int) string {
return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339)
}
tests := []struct {
name string
opts deleteOptions
codespaces []*api.Codespace
confirms map[string]bool
deleteErr error
wantErr bool
wantDeleted []string
wantStdout string
wantStderr string
}{
{
name: "by name",
opts: deleteOptions{
codespaceName: "hubot-robawt-abc",
},
codespaces: []*api.Codespace{
{
Name: "hubot-robawt-abc",
},
},
wantDeleted: []string{"hubot-robawt-abc"},
},
{
name: "by repo",
opts: deleteOptions{
repoFilter: "monalisa/spoon-knife",
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
RepositoryNWO: "monalisa/Spoon-Knife",
},
{
Name: "hubot-robawt-abc",
RepositoryNWO: "hubot/ROBAWT",
},
{
Name: "monalisa-spoonknife-c4f3",
RepositoryNWO: "monalisa/Spoon-Knife",
},
},
wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"},
},
{
name: "unused",
opts: deleteOptions{
deleteAll: true,
keepDays: 3,
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
LastUsedAt: daysAgo(1),
},
{
Name: "hubot-robawt-abc",
LastUsedAt: daysAgo(4),
},
{
Name: "monalisa-spoonknife-c4f3",
LastUsedAt: daysAgo(10),
},
},
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
},
{
name: "deletion failed",
opts: deleteOptions{
deleteAll: true,
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
},
{
Name: "hubot-robawt-abc",
},
},
deleteErr: errors.New("aborted by test"),
wantErr: true,
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-123"},
wantStderr: "error deleting codespace \"hubot-robawt-abc\": aborted by test\nerror deleting codespace \"monalisa-spoonknife-123\": aborted by test\n",
},
{
name: "with confirm",
opts: deleteOptions{
isInteractive: true,
deleteAll: true,
skipConfirm: false,
},
codespaces: []*api.Codespace{
{
Name: "monalisa-spoonknife-123",
Environment: api.CodespaceEnvironment{
GitStatus: api.CodespaceEnvironmentGitStatus{
HasUnpushedChanges: true,
},
},
},
{
Name: "hubot-robawt-abc",
Environment: api.CodespaceEnvironment{
GitStatus: api.CodespaceEnvironmentGitStatus{
HasUncommitedChanges: true,
},
},
},
{
Name: "monalisa-spoonknife-c4f3",
Environment: api.CodespaceEnvironment{
GitStatus: api.CodespaceEnvironmentGitStatus{
HasUnpushedChanges: false,
HasUncommitedChanges: false,
},
},
},
},
confirms: map[string]bool{
"Codespace monalisa-spoonknife-123 has unsaved changes. OK to delete?": false,
"Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true,
},
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
apiMock := &apiClientMock{
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
DeleteCodespaceFunc: func(_ context.Context, userLogin, name string) error {
if userLogin != user.Login {
return fmt.Errorf("unexpected user %q", userLogin)
}
if tt.deleteErr != nil {
return tt.deleteErr
}
return nil
},
}
if tt.opts.codespaceName == "" {
apiMock.ListCodespacesFunc = func(_ context.Context, userLogin string) ([]*api.Codespace, error) {
if userLogin != user.Login {
return nil, fmt.Errorf("unexpected user %q", userLogin)
}
return tt.codespaces, nil
}
} else {
apiMock.GetCodespaceTokenFunc = func(_ context.Context, userLogin, name string) (string, error) {
if userLogin != user.Login {
return "", fmt.Errorf("unexpected user %q", userLogin)
}
return "CS_TOKEN", nil
}
apiMock.GetCodespaceFunc = func(_ context.Context, token, userLogin, name string) (*api.Codespace, error) {
if userLogin != user.Login {
return nil, fmt.Errorf("unexpected user %q", userLogin)
}
if token != "CS_TOKEN" {
return nil, fmt.Errorf("unexpected token %q", token)
}
return tt.codespaces[0], nil
}
}
opts := tt.opts
opts.now = func() time.Time { return now }
opts.prompter = &prompterMock{
ConfirmFunc: func(msg string) (bool, error) {
res, found := tt.confirms[msg]
if !found {
return false, fmt.Errorf("unexpected prompt %q", msg)
}
return res, nil
},
}
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
app := &App{
apiClient: apiMock,
logger: output.NewLogger(stdout, stderr, false),
}
err := app.Delete(context.Background(), opts)
if (err != nil) != tt.wantErr {
t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr)
}
if n := len(apiMock.GetUserCalls()); n != 1 {
t.Errorf("GetUser invoked %d times, expected %d", n, 1)
}
var gotDeleted []string
for _, delArgs := range apiMock.DeleteCodespaceCalls() {
gotDeleted = append(gotDeleted, delArgs.Name)
}
sort.Strings(gotDeleted)
if !sliceEquals(gotDeleted, tt.wantDeleted) {
t.Errorf("deleted %q, want %q", gotDeleted, tt.wantDeleted)
}
if out := stdout.String(); out != tt.wantStdout {
t.Errorf("stdout = %q, want %q", out, tt.wantStdout)
}
if out := stderr.String(); out != tt.wantStderr {
t.Errorf("stderr = %q, want %q", out, tt.wantStderr)
}
})
}
}
func sliceEquals(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

63
cmd/ghcs/list.go Normal file
View file

@ -0,0 +1,63 @@
package ghcs
import (
"context"
"fmt"
"os"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
)
func newListCmd(app *App) *cobra.Command {
var asJSON bool
listCmd := &cobra.Command{
Use: "list",
Short: "List your codespaces",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.List(cmd.Context(), asJSON)
},
}
listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
return listCmd
}
func (a *App) List(ctx context.Context, asJSON bool) error {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login)
if err != nil {
return fmt.Errorf("error getting codespaces: %w", err)
}
table := output.NewTable(os.Stdout, asJSON)
table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"})
for _, codespace := range codespaces {
table.Append([]string{
codespace.Name,
codespace.RepositoryNWO,
codespace.Branch + dirtyStar(codespace.Environment.GitStatus),
codespace.Environment.State,
codespace.CreatedAt,
})
}
table.Render()
return nil
}
func dirtyStar(status api.CodespaceEnvironmentGitStatus) string {
if status.HasUncommitedChanges || status.HasUnpushedChanges {
return "*"
}
return ""
}

112
cmd/ghcs/logs.go Normal file
View file

@ -0,0 +1,112 @@
package ghcs
import (
"context"
"fmt"
"net"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/ghcs/internal/liveshare"
"github.com/spf13/cobra"
)
func newLogsCmd(app *App) *cobra.Command {
var (
codespace string
follow bool
)
logsCmd := &cobra.Command{
Use: "logs",
Short: "Access codespace logs",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.Logs(cmd.Context(), codespace, follow)
},
}
logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs")
return logsCmd
}
func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) {
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("getting user: %w", err)
}
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
if err := <-authkeys; err != nil {
return err
}
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", ":0") // arbitrary port
if err != nil {
return err
}
defer listen.Close()
localPort := listen.Addr().(*net.TCPAddr).Port
a.logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
}
cmdType := "cat"
if follow {
cmdType = "tail -f"
}
dst := fmt.Sprintf("%s@localhost", sshUser)
cmd, err := codespaces.NewRemoteCommand(
ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
)
if err != nil {
return fmt.Errorf("remote command: %w", err)
}
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
}()
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Run()
}()
select {
case err := <-tunnelClosed:
return fmt.Errorf("connection closed: %w", err)
case err := <-cmdDone:
if err != nil {
return fmt.Errorf("error retrieving logs: %w", err)
}
return nil // success
}
}

54
cmd/ghcs/main/main.go Normal file
View file

@ -0,0 +1,54 @@
package main
import (
"errors"
"fmt"
"io"
"net/http"
"os"
"github.com/github/ghcs/cmd/ghcs"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/spf13/cobra"
)
func main() {
token := os.Getenv("GITHUB_TOKEN")
rootCmd := ghcs.NewRootCmd(ghcs.NewApp(
output.NewLogger(os.Stdout, os.Stderr, false),
api.New(token, http.DefaultClient),
))
// Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still
// function without the token set.
oldPreRun := rootCmd.PersistentPreRunE
rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
if token == "" {
return errTokenMissing
}
if oldPreRun != nil {
return oldPreRun(cmd, args)
}
return nil
}
if cmd, err := rootCmd.ExecuteC(); err != nil {
explainError(os.Stderr, err, cmd)
os.Exit(1)
}
}
var errTokenMissing = errors.New("GITHUB_TOKEN is missing")
func explainError(w io.Writer, err error, cmd *cobra.Command) {
if errors.Is(err, errTokenMissing) {
fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo")
fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.")
return
}
if errors.Is(err, ghcs.ErrTooManyArgs) {
_ = cmd.Usage()
return
}
}

659
cmd/ghcs/mock_api.go Normal file
View file

@ -0,0 +1,659 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package ghcs
import (
"context"
"sync"
"github.com/github/ghcs/internal/api"
)
// apiClientMock is a mock implementation of apiClient.
//
// func TestSomethingThatUsesapiClient(t *testing.T) {
//
// // make and configure a mocked apiClient
// mockedapiClient := &apiClientMock{
// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) {
// panic("mock out the AuthorizedKeys method")
// },
// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
// panic("mock out the CreateCodespace method")
// },
// DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error {
// panic("mock out the DeleteCodespace method")
// },
// GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
// panic("mock out the GetCodespace method")
// },
// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) {
// panic("mock out the GetCodespaceRegionLocation method")
// },
// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
// panic("mock out the GetCodespaceRepositoryContents method")
// },
// GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) {
// panic("mock out the GetCodespaceToken method")
// },
// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
// panic("mock out the GetCodespacesSKUs method")
// },
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
// panic("mock out the GetRepository method")
// },
// GetUserFunc: func(ctx context.Context) (*api.User, error) {
// panic("mock out the GetUser method")
// },
// ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) {
// panic("mock out the ListCodespaces method")
// },
// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error {
// panic("mock out the StartCodespace method")
// },
// }
//
// // use mockedapiClient in code that requires apiClient
// // and then make assertions.
//
// }
type apiClientMock struct {
// AuthorizedKeysFunc mocks the AuthorizedKeys method.
AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error)
// CreateCodespaceFunc mocks the CreateCodespace method.
CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
// DeleteCodespaceFunc mocks the DeleteCodespace method.
DeleteCodespaceFunc func(ctx context.Context, user string, name string) error
// GetCodespaceFunc mocks the GetCodespace method.
GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error)
// GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method.
GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error)
// GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method.
GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
// GetCodespaceTokenFunc mocks the GetCodespaceToken method.
GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error)
// GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method.
GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error)
// GetRepositoryFunc mocks the GetRepository method.
GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error)
// GetUserFunc mocks the GetUser method.
GetUserFunc func(ctx context.Context) (*api.User, error)
// ListCodespacesFunc mocks the ListCodespaces method.
ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error)
// StartCodespaceFunc mocks the StartCodespace method.
StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error
// calls tracks calls to the methods.
calls struct {
// AuthorizedKeys holds details about calls to the AuthorizedKeys method.
AuthorizedKeys []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
}
// CreateCodespace holds details about calls to the CreateCodespace method.
CreateCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Params is the params argument value.
Params *api.CreateCodespaceParams
}
// DeleteCodespace holds details about calls to the DeleteCodespace method.
DeleteCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
// Name is the name argument value.
Name string
}
// GetCodespace holds details about calls to the GetCodespace method.
GetCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Token is the token argument value.
Token string
// User is the user argument value.
User string
// Name is the name argument value.
Name string
}
// GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method.
GetCodespaceRegionLocation []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method.
GetCodespaceRepositoryContents []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Codespace is the codespace argument value.
Codespace *api.Codespace
// Path is the path argument value.
Path string
}
// GetCodespaceToken holds details about calls to the GetCodespaceToken method.
GetCodespaceToken []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
// Name is the name argument value.
Name string
}
// GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method.
GetCodespacesSKUs []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User *api.User
// Repository is the repository argument value.
Repository *api.Repository
// Branch is the branch argument value.
Branch string
// Location is the location argument value.
Location string
}
// GetRepository holds details about calls to the GetRepository method.
GetRepository []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Nwo is the nwo argument value.
Nwo string
}
// GetUser holds details about calls to the GetUser method.
GetUser []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// ListCodespaces holds details about calls to the ListCodespaces method.
ListCodespaces []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
}
// StartCodespace holds details about calls to the StartCodespace method.
StartCodespace []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Token is the token argument value.
Token string
// Codespace is the codespace argument value.
Codespace *api.Codespace
}
}
lockAuthorizedKeys sync.RWMutex
lockCreateCodespace sync.RWMutex
lockDeleteCodespace sync.RWMutex
lockGetCodespace sync.RWMutex
lockGetCodespaceRegionLocation sync.RWMutex
lockGetCodespaceRepositoryContents sync.RWMutex
lockGetCodespaceToken sync.RWMutex
lockGetCodespacesSKUs sync.RWMutex
lockGetRepository sync.RWMutex
lockGetUser sync.RWMutex
lockListCodespaces sync.RWMutex
lockStartCodespace sync.RWMutex
}
// AuthorizedKeys calls AuthorizedKeysFunc.
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
if mock.AuthorizedKeysFunc == nil {
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
}
callInfo := struct {
Ctx context.Context
User string
}{
Ctx: ctx,
User: user,
}
mock.lockAuthorizedKeys.Lock()
mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo)
mock.lockAuthorizedKeys.Unlock()
return mock.AuthorizedKeysFunc(ctx, user)
}
// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys.
// Check the length with:
// len(mockedapiClient.AuthorizedKeysCalls())
func (mock *apiClientMock) AuthorizedKeysCalls() []struct {
Ctx context.Context
User string
} {
var calls []struct {
Ctx context.Context
User string
}
mock.lockAuthorizedKeys.RLock()
calls = mock.calls.AuthorizedKeys
mock.lockAuthorizedKeys.RUnlock()
return calls
}
// CreateCodespace calls CreateCodespaceFunc.
func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
if mock.CreateCodespaceFunc == nil {
panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Params *api.CreateCodespaceParams
}{
Ctx: ctx,
Params: params,
}
mock.lockCreateCodespace.Lock()
mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo)
mock.lockCreateCodespace.Unlock()
return mock.CreateCodespaceFunc(ctx, params)
}
// CreateCodespaceCalls gets all the calls that were made to CreateCodespace.
// Check the length with:
// len(mockedapiClient.CreateCodespaceCalls())
func (mock *apiClientMock) CreateCodespaceCalls() []struct {
Ctx context.Context
Params *api.CreateCodespaceParams
} {
var calls []struct {
Ctx context.Context
Params *api.CreateCodespaceParams
}
mock.lockCreateCodespace.RLock()
calls = mock.calls.CreateCodespace
mock.lockCreateCodespace.RUnlock()
return calls
}
// DeleteCodespace calls DeleteCodespaceFunc.
func (mock *apiClientMock) DeleteCodespace(ctx context.Context, user string, name string) error {
if mock.DeleteCodespaceFunc == nil {
panic("apiClientMock.DeleteCodespaceFunc: method is nil but apiClient.DeleteCodespace was just called")
}
callInfo := struct {
Ctx context.Context
User string
Name string
}{
Ctx: ctx,
User: user,
Name: name,
}
mock.lockDeleteCodespace.Lock()
mock.calls.DeleteCodespace = append(mock.calls.DeleteCodespace, callInfo)
mock.lockDeleteCodespace.Unlock()
return mock.DeleteCodespaceFunc(ctx, user, name)
}
// DeleteCodespaceCalls gets all the calls that were made to DeleteCodespace.
// Check the length with:
// len(mockedapiClient.DeleteCodespaceCalls())
func (mock *apiClientMock) DeleteCodespaceCalls() []struct {
Ctx context.Context
User string
Name string
} {
var calls []struct {
Ctx context.Context
User string
Name string
}
mock.lockDeleteCodespace.RLock()
calls = mock.calls.DeleteCodespace
mock.lockDeleteCodespace.RUnlock()
return calls
}
// GetCodespace calls GetCodespaceFunc.
func (mock *apiClientMock) GetCodespace(ctx context.Context, token string, user string, name string) (*api.Codespace, error) {
if mock.GetCodespaceFunc == nil {
panic("apiClientMock.GetCodespaceFunc: method is nil but apiClient.GetCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Token string
User string
Name string
}{
Ctx: ctx,
Token: token,
User: user,
Name: name,
}
mock.lockGetCodespace.Lock()
mock.calls.GetCodespace = append(mock.calls.GetCodespace, callInfo)
mock.lockGetCodespace.Unlock()
return mock.GetCodespaceFunc(ctx, token, user, name)
}
// GetCodespaceCalls gets all the calls that were made to GetCodespace.
// Check the length with:
// len(mockedapiClient.GetCodespaceCalls())
func (mock *apiClientMock) GetCodespaceCalls() []struct {
Ctx context.Context
Token string
User string
Name string
} {
var calls []struct {
Ctx context.Context
Token string
User string
Name string
}
mock.lockGetCodespace.RLock()
calls = mock.calls.GetCodespace
mock.lockGetCodespace.RUnlock()
return calls
}
// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc.
func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) {
if mock.GetCodespaceRegionLocationFunc == nil {
panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockGetCodespaceRegionLocation.Lock()
mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo)
mock.lockGetCodespaceRegionLocation.Unlock()
return mock.GetCodespaceRegionLocationFunc(ctx)
}
// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation.
// Check the length with:
// len(mockedapiClient.GetCodespaceRegionLocationCalls())
func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockGetCodespaceRegionLocation.RLock()
calls = mock.calls.GetCodespaceRegionLocation
mock.lockGetCodespaceRegionLocation.RUnlock()
return calls
}
// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc.
func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) {
if mock.GetCodespaceRepositoryContentsFunc == nil {
panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called")
}
callInfo := struct {
Ctx context.Context
Codespace *api.Codespace
Path string
}{
Ctx: ctx,
Codespace: codespace,
Path: path,
}
mock.lockGetCodespaceRepositoryContents.Lock()
mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo)
mock.lockGetCodespaceRepositoryContents.Unlock()
return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path)
}
// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents.
// Check the length with:
// len(mockedapiClient.GetCodespaceRepositoryContentsCalls())
func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct {
Ctx context.Context
Codespace *api.Codespace
Path string
} {
var calls []struct {
Ctx context.Context
Codespace *api.Codespace
Path string
}
mock.lockGetCodespaceRepositoryContents.RLock()
calls = mock.calls.GetCodespaceRepositoryContents
mock.lockGetCodespaceRepositoryContents.RUnlock()
return calls
}
// GetCodespaceToken calls GetCodespaceTokenFunc.
func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) {
if mock.GetCodespaceTokenFunc == nil {
panic("apiClientMock.GetCodespaceTokenFunc: method is nil but apiClient.GetCodespaceToken was just called")
}
callInfo := struct {
Ctx context.Context
User string
Name string
}{
Ctx: ctx,
User: user,
Name: name,
}
mock.lockGetCodespaceToken.Lock()
mock.calls.GetCodespaceToken = append(mock.calls.GetCodespaceToken, callInfo)
mock.lockGetCodespaceToken.Unlock()
return mock.GetCodespaceTokenFunc(ctx, user, name)
}
// GetCodespaceTokenCalls gets all the calls that were made to GetCodespaceToken.
// Check the length with:
// len(mockedapiClient.GetCodespaceTokenCalls())
func (mock *apiClientMock) GetCodespaceTokenCalls() []struct {
Ctx context.Context
User string
Name string
} {
var calls []struct {
Ctx context.Context
User string
Name string
}
mock.lockGetCodespaceToken.RLock()
calls = mock.calls.GetCodespaceToken
mock.lockGetCodespaceToken.RUnlock()
return calls
}
// GetCodespacesSKUs calls GetCodespacesSKUsFunc.
func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) {
if mock.GetCodespacesSKUsFunc == nil {
panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called")
}
callInfo := struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
}{
Ctx: ctx,
User: user,
Repository: repository,
Branch: branch,
Location: location,
}
mock.lockGetCodespacesSKUs.Lock()
mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo)
mock.lockGetCodespacesSKUs.Unlock()
return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location)
}
// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs.
// Check the length with:
// len(mockedapiClient.GetCodespacesSKUsCalls())
func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
} {
var calls []struct {
Ctx context.Context
User *api.User
Repository *api.Repository
Branch string
Location string
}
mock.lockGetCodespacesSKUs.RLock()
calls = mock.calls.GetCodespacesSKUs
mock.lockGetCodespacesSKUs.RUnlock()
return calls
}
// GetRepository calls GetRepositoryFunc.
func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) {
if mock.GetRepositoryFunc == nil {
panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called")
}
callInfo := struct {
Ctx context.Context
Nwo string
}{
Ctx: ctx,
Nwo: nwo,
}
mock.lockGetRepository.Lock()
mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo)
mock.lockGetRepository.Unlock()
return mock.GetRepositoryFunc(ctx, nwo)
}
// GetRepositoryCalls gets all the calls that were made to GetRepository.
// Check the length with:
// len(mockedapiClient.GetRepositoryCalls())
func (mock *apiClientMock) GetRepositoryCalls() []struct {
Ctx context.Context
Nwo string
} {
var calls []struct {
Ctx context.Context
Nwo string
}
mock.lockGetRepository.RLock()
calls = mock.calls.GetRepository
mock.lockGetRepository.RUnlock()
return calls
}
// GetUser calls GetUserFunc.
func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) {
if mock.GetUserFunc == nil {
panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockGetUser.Lock()
mock.calls.GetUser = append(mock.calls.GetUser, callInfo)
mock.lockGetUser.Unlock()
return mock.GetUserFunc(ctx)
}
// GetUserCalls gets all the calls that were made to GetUser.
// Check the length with:
// len(mockedapiClient.GetUserCalls())
func (mock *apiClientMock) GetUserCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockGetUser.RLock()
calls = mock.calls.GetUser
mock.lockGetUser.RUnlock()
return calls
}
// ListCodespaces calls ListCodespacesFunc.
func (mock *apiClientMock) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) {
if mock.ListCodespacesFunc == nil {
panic("apiClientMock.ListCodespacesFunc: method is nil but apiClient.ListCodespaces was just called")
}
callInfo := struct {
Ctx context.Context
User string
}{
Ctx: ctx,
User: user,
}
mock.lockListCodespaces.Lock()
mock.calls.ListCodespaces = append(mock.calls.ListCodespaces, callInfo)
mock.lockListCodespaces.Unlock()
return mock.ListCodespacesFunc(ctx, user)
}
// ListCodespacesCalls gets all the calls that were made to ListCodespaces.
// Check the length with:
// len(mockedapiClient.ListCodespacesCalls())
func (mock *apiClientMock) ListCodespacesCalls() []struct {
Ctx context.Context
User string
} {
var calls []struct {
Ctx context.Context
User string
}
mock.lockListCodespaces.RLock()
calls = mock.calls.ListCodespaces
mock.lockListCodespaces.RUnlock()
return calls
}
// StartCodespace calls StartCodespaceFunc.
func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error {
if mock.StartCodespaceFunc == nil {
panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called")
}
callInfo := struct {
Ctx context.Context
Token string
Codespace *api.Codespace
}{
Ctx: ctx,
Token: token,
Codespace: codespace,
}
mock.lockStartCodespace.Lock()
mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo)
mock.lockStartCodespace.Unlock()
return mock.StartCodespaceFunc(ctx, token, codespace)
}
// StartCodespaceCalls gets all the calls that were made to StartCodespace.
// Check the length with:
// len(mockedapiClient.StartCodespaceCalls())
func (mock *apiClientMock) StartCodespaceCalls() []struct {
Ctx context.Context
Token string
Codespace *api.Codespace
} {
var calls []struct {
Ctx context.Context
Token string
Codespace *api.Codespace
}
mock.lockStartCodespace.RLock()
calls = mock.calls.StartCodespace
mock.lockStartCodespace.RUnlock()
return calls
}

69
cmd/ghcs/mock_prompter.go Normal file
View file

@ -0,0 +1,69 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package ghcs
import (
"sync"
)
// prompterMock is a mock implementation of prompter.
//
// func TestSomethingThatUsesprompter(t *testing.T) {
//
// // make and configure a mocked prompter
// mockedprompter := &prompterMock{
// ConfirmFunc: func(message string) (bool, error) {
// panic("mock out the Confirm method")
// },
// }
//
// // use mockedprompter in code that requires prompter
// // and then make assertions.
//
// }
type prompterMock struct {
// ConfirmFunc mocks the Confirm method.
ConfirmFunc func(message string) (bool, error)
// calls tracks calls to the methods.
calls struct {
// Confirm holds details about calls to the Confirm method.
Confirm []struct {
// Message is the message argument value.
Message string
}
}
lockConfirm sync.RWMutex
}
// Confirm calls ConfirmFunc.
func (mock *prompterMock) Confirm(message string) (bool, error) {
if mock.ConfirmFunc == nil {
panic("prompterMock.ConfirmFunc: method is nil but prompter.Confirm was just called")
}
callInfo := struct {
Message string
}{
Message: message,
}
mock.lockConfirm.Lock()
mock.calls.Confirm = append(mock.calls.Confirm, callInfo)
mock.lockConfirm.Unlock()
return mock.ConfirmFunc(message)
}
// ConfirmCalls gets all the calls that were made to Confirm.
// Check the length with:
// len(mockedprompter.ConfirmCalls())
func (mock *prompterMock) ConfirmCalls() []struct {
Message string
} {
var calls []struct {
Message string
}
mock.lockConfirm.RLock()
calls = mock.calls.Confirm
mock.lockConfirm.RUnlock()
return calls
}

View file

@ -0,0 +1,55 @@
package output
import (
"encoding/json"
"io"
"strings"
"unicode"
)
type jsonwriter struct {
w io.Writer
pretty bool
cols []string
data []interface{}
}
func (j *jsonwriter) SetHeader(cols []string) {
j.cols = cols
}
func (j *jsonwriter) Append(values []string) {
row := make(map[string]string)
for i, v := range values {
row[camelize(j.cols[i])] = v
}
j.data = append(j.data, row)
}
func (j *jsonwriter) Render() {
enc := json.NewEncoder(j.w)
if j.pretty {
enc.SetIndent("", " ")
}
_ = enc.Encode(j.data)
}
func camelize(s string) string {
var b strings.Builder
capitalizeNext := false
for i, r := range s {
if r == ' ' {
capitalizeNext = true
continue
}
if capitalizeNext {
b.WriteRune(unicode.ToUpper(r))
capitalizeNext = false
} else if i == 0 {
b.WriteRune(unicode.ToLower(r))
} else {
b.WriteRune(r)
}
}
return b.String()
}

View file

@ -0,0 +1,31 @@
package output
import (
"io"
"os"
"github.com/olekukonko/tablewriter"
"golang.org/x/term"
)
type Table interface {
SetHeader([]string)
Append([]string)
Render()
}
func NewTable(w io.Writer, asJSON bool) Table {
isTTY := isTTY(w)
if asJSON {
return &jsonwriter{w: w, pretty: isTTY}
}
if isTTY {
return tablewriter.NewWriter(w)
}
return &tabwriter{w: w}
}
func isTTY(w io.Writer) bool {
f, ok := w.(*os.File)
return ok && term.IsTerminal(int(f.Fd()))
}

View file

@ -0,0 +1,25 @@
package output
import (
"fmt"
"io"
)
type tabwriter struct {
w io.Writer
}
func (j *tabwriter) SetHeader([]string) {}
func (j *tabwriter) Append(values []string) {
var sep string
for i, v := range values {
if i == 1 {
sep = "\t"
}
fmt.Fprintf(j.w, "%s%s", sep, v)
}
fmt.Fprint(j.w, "\n")
}
func (j *tabwriter) Render() {}

59
cmd/ghcs/output/logger.go Normal file
View file

@ -0,0 +1,59 @@
package output
import (
"fmt"
"io"
)
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
// Disable the Logger to prevent it from writing to stdout in a TTY environment.
func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
return &Logger{
out: stdout,
errout: stderr,
enabled: !disabled && isTTY(stdout),
}
}
// Logger writes to the given stdout/stderr writers.
// If not enabled, Print functions will noop but Error functions will continue
// to write to the stderr writer.
type Logger struct {
out io.Writer
errout io.Writer
enabled bool
}
// Print writes the arguments to the stdout writer.
func (l *Logger) Print(v ...interface{}) (int, error) {
if !l.enabled {
return 0, nil
}
return fmt.Fprint(l.out, v...)
}
// Println writes the arguments to the stdout writer with a newline at the end.
func (l *Logger) Println(v ...interface{}) (int, error) {
if !l.enabled {
return 0, nil
}
return fmt.Fprintln(l.out, v...)
}
// Printf writes the formatted arguments to the stdout writer.
func (l *Logger) Printf(f string, v ...interface{}) (int, error) {
if !l.enabled {
return 0, nil
}
return fmt.Fprintf(l.out, f, v...)
}
// Errorf writes the formatted arguments to the stderr writer.
func (l *Logger) Errorf(f string, v ...interface{}) (int, error) {
return fmt.Fprintf(l.errout, f, v...)
}
// Errorln writes the arguments to the stderr writer with a newline at the end.
func (l *Logger) Errorln(v ...interface{}) (int, error) {
return fmt.Fprintln(l.errout, v...)
}

330
cmd/ghcs/ports.go Normal file
View file

@ -0,0 +1,330 @@
package ghcs
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/ghcs/internal/liveshare"
"github.com/muhammadmuzzammil1998/jsonc"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
// according to the specified flags.
func newPortsCmd(app *App) *cobra.Command {
var (
codespace string
asJSON bool
)
portsCmd := &cobra.Command{
Use: "ports",
Short: "List ports in a codespace",
Args: noArgsConstraint,
RunE: func(cmd *cobra.Command, args []string) error {
return app.ListPorts(cmd.Context(), codespace, asJSON)
},
}
portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
portsCmd.AddCommand(newPortsPublicCmd(app))
portsCmd.AddCommand(newPortsPrivateCmd(app))
portsCmd.AddCommand(newPortsForwardCmd(app))
return portsCmd
}
// ListPorts lists known ports in a codespace.
func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
// TODO(josebalius): remove special handling of this error here and it other places
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error choosing codespace: %w", err)
}
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
a.logger.Println("Loading ports...")
ports, err := session.GetSharedServers(ctx)
if err != nil {
return fmt.Errorf("error getting ports of shared servers: %w", err)
}
devContainerResult := <-devContainerCh
if devContainerResult.err != nil {
// Warn about failure to read the devcontainer file. Not a ghcs command error.
_, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error())
}
table := output.NewTable(os.Stdout, asJSON)
table.SetHeader([]string{"Label", "Port", "Public", "Browse URL"})
for _, port := range ports {
sourcePort := strconv.Itoa(port.SourcePort)
var portName string
if devContainerResult.devContainer != nil {
if attributes, ok := devContainerResult.devContainer.PortAttributes[sourcePort]; ok {
portName = attributes.Label
}
}
table.Append([]string{
portName,
sourcePort,
strings.ToUpper(strconv.FormatBool(port.IsPublic)),
fmt.Sprintf("https://%s-%s.githubpreview.dev/", codespace.Name, sourcePort),
})
}
table.Render()
return nil
}
type devContainerResult struct {
devContainer *devContainer
err error
}
type devContainer struct {
PortAttributes map[string]portAttribute `json:"portsAttributes"`
}
type portAttribute struct {
Label string `json:"label"`
}
func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult {
ch := make(chan devContainerResult, 1)
go func() {
contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json")
if err != nil {
ch <- devContainerResult{nil, fmt.Errorf("error getting content: %w", err)}
return
}
if contents == nil {
ch <- devContainerResult{nil, nil}
return
}
convertedJSON := normalizeJSON(jsonc.ToJSON(contents))
if !jsonc.Valid(convertedJSON) {
ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")}
return
}
var container devContainer
if err := json.Unmarshal(convertedJSON, &container); err != nil {
ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %w", err)}
return
}
ch <- devContainerResult{&container, nil}
}()
return ch
}
// newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public.
func newPortsPublicCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "public <port>",
Short: "Mark port as public",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
codespace, err := cmd.Flags().GetString("codespace")
if err != nil {
// should only happen if flag is not defined
// or if the flag is not of string type
// since it's a persistent flag that we control it should never happen
return fmt.Errorf("get codespace flag: %w", err)
}
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true)
},
}
}
// newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private.
func newPortsPrivateCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "private <port>",
Short: "Mark port as private",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
codespace, err := cmd.Flags().GetString("codespace")
if err != nil {
// should only happen if flag is not defined
// or if the flag is not of string type
// since it's a persistent flag that we control it should never happen
return fmt.Errorf("get codespace flag: %w", err)
}
return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false)
},
}
}
func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) {
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
port, err := strconv.Atoi(sourcePort)
if err != nil {
return fmt.Errorf("error reading port number: %w", err)
}
if err := session.UpdateSharedVisibility(ctx, port, public); err != nil {
return fmt.Errorf("error update port to public: %w", err)
}
state := "PUBLIC"
if !public {
state = "PRIVATE"
}
a.logger.Printf("Port %s is now %s.\n", sourcePort, state)
return nil
}
// NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of
// port pairs from the codespace to localhost.
func newPortsForwardCmd(app *App) *cobra.Command {
return &cobra.Command{
Use: "forward <remote-port>:<local-port>...",
Short: "Forward ports",
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
codespace, err := cmd.Flags().GetString("codespace")
if err != nil {
// should only happen if flag is not defined
// or if the flag is not of string type
// since it's a persistent flag that we control it should never happen
return fmt.Errorf("get codespace flag: %w", err)
}
return app.ForwardPorts(cmd.Context(), codespace, args)
},
}
}
func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) {
portPairs, err := getPortPairs(ports)
if err != nil {
return fmt.Errorf("get port pairs: %w", err)
}
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
if err == errNoCodespaces {
return err
}
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
// Run forwarding of all ports concurrently, aborting all of
// them at the first failure, including cancellation of the context.
group, ctx := errgroup.WithContext(ctx)
for _, pair := range portPairs {
pair := pair
group.Go(func() error {
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
if err != nil {
return err
}
defer listen.Close()
a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
return fwd.ForwardToListener(ctx, listen) // error always non-nil
})
}
return group.Wait() // first error
}
type portPair struct {
remote, local int
}
// getPortPairs parses a list of strings of form "%d:%d" into pairs of (remote, local) numbers.
func getPortPairs(ports []string) ([]portPair, error) {
pp := make([]portPair, 0, len(ports))
for _, portString := range ports {
parts := strings.Split(portString, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("port pair: %q is not valid", portString)
}
remote, err := strconv.Atoi(parts[0])
if err != nil {
return pp, fmt.Errorf("convert remote port to int: %w", err)
}
local, err := strconv.Atoi(parts[1])
if err != nil {
return pp, fmt.Errorf("convert local port to int: %w", err)
}
pp = append(pp, portPair{remote, local})
}
return pp, nil
}
func normalizeJSON(j []byte) []byte {
// remove trailing commas
return bytes.ReplaceAll(j, []byte("},}"), []byte("}}"))
}

30
cmd/ghcs/root.go Normal file
View file

@ -0,0 +1,30 @@
package ghcs
import (
"github.com/spf13/cobra"
)
var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number.
func NewRootCmd(app *App) *cobra.Command {
root := &cobra.Command{
Use: "ghcs",
SilenceUsage: true, // don't print usage message after each error (see #80)
SilenceErrors: false, // print errors automatically so that main need not
Long: `Unofficial CLI tool to manage GitHub Codespaces.
Running commands requires the GITHUB_TOKEN environment variable to be set to a
token to access the GitHub API with.`,
Version: version,
}
root.AddCommand(newCodeCmd(app))
root.AddCommand(newCreateCmd(app))
root.AddCommand(newDeleteCmd(app))
root.AddCommand(newListCmd(app))
root.AddCommand(newLogsCmd(app))
root.AddCommand(newPortsCmd(app))
root.AddCommand(newSSHCmd(app))
return root
}

105
cmd/ghcs/ssh.go Normal file
View file

@ -0,0 +1,105 @@
package ghcs
import (
"context"
"fmt"
"net"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/ghcs/internal/liveshare"
"github.com/spf13/cobra"
)
func newSSHCmd(app *App) *cobra.Command {
var sshProfile, codespaceName string
var sshServerPort int
sshCmd := &cobra.Command{
Use: "ssh [flags] [--] [ssh-flags] [command]",
Short: "SSH into a codespace",
RunE: func(cmd *cobra.Command, args []string) error {
return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort)
},
}
sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use")
sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace")
return sshCmd
}
// SSH opens an ssh session or runs an ssh command in a codespace.
func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
// Ensure all child tasks (e.g. port forwarding) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
user, err := a.apiClient.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
defer safeClose(session, &err)
if err := <-authkeys; err != nil {
return err
}
a.logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
}
usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
// Ensure local port is listening before client (Shell) connects.
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort))
if err != nil {
return err
}
defer listen.Close()
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
connectDestination := sshProfile
if connectDestination == "" {
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
}
a.logger.Println("Ready...")
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
}()
shellClosed := make(chan error, 1)
go func() {
shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort)
}()
select {
case err := <-tunnelClosed:
return fmt.Errorf("tunnel closed: %w", err)
case err := <-shellClosed:
if err != nil {
return fmt.Errorf("shell closed: %w", err)
}
return nil // success
}
}