diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 1f080eb5d..f5cda54d4 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -149,7 +149,7 @@ jobs: GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} - name: Prepare PATH id: setupmsbuild - uses: microsoft/setup-msbuild@v1.1.3 + uses: microsoft/setup-msbuild@v1.3.1 - name: Build MSI id: buildmsi shell: bash diff --git a/.goreleaser.yml b/.goreleaser.yml index 5eb035433..860b337b3 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -16,7 +16,6 @@ builds: main: ./cmd/gh ldflags: - -s -w -X github.com/cli/cli/v2/internal/build.Version={{.Version}} -X github.com/cli/cli/v2/internal/build.Date={{time "2006-01-02"}} - - -X main.updaterEnabled=cli/cli id: macos goos: [darwin] goarch: [amd64] diff --git a/README.md b/README.md index 21e7269d8..c95722466 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ For [installation options see below](#installation), for usage instructions [see If anything feels off, or if you feel that some functionality is missing, please check out the [contributing page][contributing]. There you will find instructions for sharing your feedback, building the tool locally, and submitting pull requests to the project. +If you are a hubber and are interested in shipping new commands for the CLI, check out our [doc on internal contributions][intake-doc]. + ## Installation @@ -128,3 +130,4 @@ tool. Check out our [more detailed explanation][gh-vs-hub] to learn more. [contributing]: ./.github/CONTRIBUTING.md [gh-vs-hub]: ./docs/gh-vs-hub.md [build from source]: ./docs/source.md +[intake-doc]: ./docs/working-with-us.md diff --git a/api/http_client.go b/api/http_client.go index 81693cbd1..83f228409 100644 --- a/api/http_client.go +++ b/api/http_client.go @@ -64,6 +64,8 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) { client.Transport = AddAuthTokenHeader(client.Transport, opts.Config) } + client.Transport = AddASCIISanitizer(client.Transport) + return client, nil } diff --git a/api/queries_issue.go b/api/queries_issue.go index 1fceac39e..58053f8df 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -40,6 +40,7 @@ type Issue struct { Assignees Assignees Labels Labels ProjectCards ProjectCards + ProjectItems ProjectItems Milestone *Milestone ReactionGroups ReactionGroups IsPinned bool @@ -86,6 +87,10 @@ type ProjectCards struct { TotalCount int } +type ProjectItems struct { + Nodes []*ProjectV2Item +} + type ProjectInfo struct { Project struct { Name string `json:"name"` @@ -95,6 +100,14 @@ type ProjectInfo struct { } `json:"column"` } +type ProjectV2Item struct { + ID string `json:"id"` + Project struct { + ID string `json:"id"` + Title string `json:"title"` + } +} + func (p ProjectCards) ProjectNames() []string { names := make([]string, len(p.Nodes)) for i, c := range p.Nodes { @@ -103,6 +116,14 @@ func (p ProjectCards) ProjectNames() []string { return names } +func (p ProjectItems) ProjectTitles() []string { + titles := make([]string, len(p.Nodes)) + for i, c := range p.Nodes { + titles[i] = c.Project.Title + } + return titles +} + type Milestone struct { Number int `json:"number"` Title string `json:"title"` @@ -158,6 +179,7 @@ func IssueCreate(client *Client, repo *Repository, params map[string]interface{} mutation IssueCreate($input: CreateIssueInput!) { createIssue(input: $input) { issue { + id url } } @@ -167,7 +189,13 @@ func IssueCreate(client *Client, repo *Repository, params map[string]interface{} "repositoryId": repo.ID, } for key, val := range params { - inputParams[key] = val + switch key { + case "assigneeIds", "body", "issueTemplate", "labelIds", "milestoneId", "projectIds", "repositoryId", "title": + inputParams[key] = val + case "projectV2Ids": + default: + return nil, fmt.Errorf("invalid IssueCreate mutation parameter %s", key) + } } variables := map[string]interface{}{ "input": inputParams, @@ -183,8 +211,23 @@ func IssueCreate(client *Client, repo *Repository, params map[string]interface{} if err != nil { return nil, err } + issue := &result.CreateIssue.Issue - return &result.CreateIssue.Issue, nil + // projectV2 parameters aren't supported in the `createIssue` mutation, + // so add them after the issue has been created. + projectV2Ids, ok := params["projectV2Ids"].([]string) + if ok { + projectItems := make(map[string]string, len(projectV2Ids)) + for _, p := range projectV2Ids { + projectItems[p] = issue.ID + } + err = UpdateProjectV2Items(client, repo, projectItems, nil) + if err != nil { + return issue, err + } + } + + return issue, nil } type IssueStatusOptions struct { diff --git a/api/queries_org.go b/api/queries_org.go index 502b6f39e..e57d7f08b 100644 --- a/api/queries_org.go +++ b/api/queries_org.go @@ -5,7 +5,7 @@ import ( "github.com/shurcooL/githubv4" ) -// OrganizationProjects fetches all open projects for an organization +// OrganizationProjects fetches all open projects for an organization. func OrganizationProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) { type responseData struct { Organization struct { @@ -42,6 +42,45 @@ func OrganizationProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, return projects, nil } +// OrganizationProjectsV2 fetches all open projectsV2 for an organization. +func OrganizationProjectsV2(client *Client, repo ghrepo.Interface) ([]RepoProjectV2, error) { + type responseData struct { + Organization struct { + ProjectsV2 struct { + Nodes []RepoProjectV2 + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"projectsV2(first: 100, orderBy: {field: TITLE, direction: ASC}, after: $endCursor, query: $query)"` + } `graphql:"organization(login: $owner)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "endCursor": (*githubv4.String)(nil), + "query": githubv4.String("is:open"), + } + + var projectsV2 []RepoProjectV2 + for { + var query responseData + err := client.Query(repo.RepoHost(), "OrganizationProjectV2List", &query, variables) + if err != nil { + return nil, err + } + + projectsV2 = append(projectsV2, query.Organization.ProjectsV2.Nodes...) + + if !query.Organization.ProjectsV2.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Organization.ProjectsV2.PageInfo.EndCursor) + } + + return projectsV2, nil +} + type OrgTeam struct { ID string Slug string diff --git a/api/queries_pr.go b/api/queries_pr.go index af4e6bb0f..87ffd658d 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -74,6 +74,7 @@ type PullRequest struct { Assignees Assignees Labels Labels ProjectCards ProjectCards + ProjectItems ProjectItems Milestone *Milestone Comments Comments ReactionGroups ReactionGroups @@ -378,6 +379,19 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter } } + // projectsV2 are added in yet another mutation + projectV2Ids, ok := params["projectV2Ids"].([]string) + if ok { + projectItems := make(map[string]string, len(projectV2Ids)) + for _, p := range projectV2Ids { + projectItems[p] = pr.ID + } + err = UpdateProjectV2Items(client, repo, projectItems, nil) + if err != nil { + return pr, err + } + } + return pr, nil } diff --git a/api/queries_pr_review.go b/api/queries_pr_review.go index aa2b7fedb..d5565b54b 100644 --- a/api/queries_pr_review.go +++ b/api/queries_pr_review.go @@ -31,7 +31,7 @@ type PullRequestReviews struct { type PullRequestReview struct { ID string `json:"id"` - Author Author `json:"author"` + Author CommentAuthor `json:"author"` AuthorAssociation string `json:"authorAssociation"` Body string `json:"body"` SubmittedAt *time.Time `json:"submittedAt"` diff --git a/api/queries_projects_v2.go b/api/queries_projects_v2.go new file mode 100644 index 000000000..e3b214f69 --- /dev/null +++ b/api/queries_projects_v2.go @@ -0,0 +1,144 @@ +package api + +import ( + "fmt" + "strings" + + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/shurcooL/githubv4" +) + +const ( + errorProjectsV2ReadScope = "field requires one of the following scopes: ['read:project']" + errorProjectsV2RepositoryField = "Field 'projectsV2' doesn't exist on type 'Repository'" + errorProjectsV2OrganizationField = "Field 'projectsV2' doesn't exist on type 'Organization'" + errorProjectsV2IssueField = "Field 'projectItems' doesn't exist on type 'Issue'" + errorProjectsV2PullRequestField = "Field 'projectItems' doesn't exist on type 'PullRequest'" +) + +// UpdateProjectV2Items uses the addProjectV2ItemById and the deleteProjectV2Item mutations +// to add and delete items from projects. The addProjectItems and deleteProjectItems arguments are +// mappings between a project and an item. This function can be used across multiple projects +// and items. Note that the deleteProjectV2Item mutation requires the item id from the project not +// the global id. +func UpdateProjectV2Items(client *Client, repo ghrepo.Interface, addProjectItems, deleteProjectItems map[string]string) error { + l := len(addProjectItems) + len(deleteProjectItems) + if l == 0 { + return nil + } + inputs := make([]string, 0, l) + mutations := make([]string, 0, l) + variables := make(map[string]interface{}, l) + var i int + + for project, item := range addProjectItems { + inputs = append(inputs, fmt.Sprintf("$input_%03d: AddProjectV2ItemByIdInput!", i)) + mutations = append(mutations, fmt.Sprintf("add_%03d: addProjectV2ItemById(input: $input_%03d) { item { id } }", i, i)) + variables[fmt.Sprintf("input_%03d", i)] = map[string]interface{}{"contentId": item, "projectId": project} + i++ + } + + for project, item := range deleteProjectItems { + inputs = append(inputs, fmt.Sprintf("$input_%03d: DeleteProjectV2ItemInput!", i)) + mutations = append(mutations, fmt.Sprintf("delete_%03d: deleteProjectV2Item(input: $input_%03d) { deletedItemId }", i, i)) + variables[fmt.Sprintf("input_%03d", i)] = map[string]interface{}{"itemId": item, "projectId": project} + i++ + } + + query := fmt.Sprintf(`mutation UpdateProjectV2Items(%s) {%s}`, strings.Join(inputs, " "), strings.Join(mutations, " ")) + + return client.GraphQL(repo.RepoHost(), query, variables, nil) +} + +// ProjectsV2ItemsForIssue fetches all ProjectItems for an issue. +func ProjectsV2ItemsForIssue(client *Client, repo ghrepo.Interface, issue *Issue) error { + type response struct { + Repository struct { + Issue struct { + ProjectItems struct { + Nodes []*ProjectV2Item + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"projectItems(first: 100, after: $endCursor)"` + } `graphql:"issue(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "name": githubv4.String(repo.RepoName()), + "number": githubv4.Int(issue.Number), + "endCursor": (*githubv4.String)(nil), + } + var items ProjectItems + for { + var query response + err := client.Query(repo.RepoHost(), "IssueProjectItems", &query, variables) + if err != nil { + return err + } + items.Nodes = append(items.Nodes, query.Repository.Issue.ProjectItems.Nodes...) + if !query.Repository.Issue.ProjectItems.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Repository.Issue.ProjectItems.PageInfo.EndCursor) + } + issue.ProjectItems = items + return nil +} + +// ProjectsV2ItemsForPullRequest fetches all ProjectItems for a pull request. +func ProjectsV2ItemsForPullRequest(client *Client, repo ghrepo.Interface, pr *PullRequest) error { + type response struct { + Repository struct { + PullRequest struct { + ProjectItems struct { + Nodes []*ProjectV2Item + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"projectItems(first: 100, after: $endCursor)"` + } `graphql:"pullRequest(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "name": githubv4.String(repo.RepoName()), + "number": githubv4.Int(pr.Number), + "endCursor": (*githubv4.String)(nil), + } + var items ProjectItems + for { + var query response + err := client.Query(repo.RepoHost(), "PullRequestProjectItems", &query, variables) + if err != nil { + return err + } + items.Nodes = append(items.Nodes, query.Repository.PullRequest.ProjectItems.Nodes...) + if !query.Repository.PullRequest.ProjectItems.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Repository.PullRequest.ProjectItems.PageInfo.EndCursor) + } + pr.ProjectItems = items + return nil +} + +// When querying ProjectsV2 fields we generally dont want to show the user +// scope errors and field does not exist errors. ProjectsV2IgnorableError +// checks against known error strings to see if an error can be safely ignored. +// Due to the fact that the GQLClient can return multiple types of errors +// this uses brittle string comparison to check against the known error strings. +func ProjectsV2IgnorableError(err error) bool { + msg := err.Error() + if strings.Contains(msg, errorProjectsV2ReadScope) || + strings.Contains(msg, errorProjectsV2RepositoryField) || + strings.Contains(msg, errorProjectsV2OrganizationField) || + strings.Contains(msg, errorProjectsV2IssueField) || + strings.Contains(msg, errorProjectsV2PullRequestField) { + return true + } + return false +} diff --git a/api/queries_projects_v2_test.go b/api/queries_projects_v2_test.go new file mode 100644 index 000000000..bf6a618ba --- /dev/null +++ b/api/queries_projects_v2_test.go @@ -0,0 +1,266 @@ +package api + +import ( + "errors" + "fmt" + "sort" + "strings" + "testing" + "unicode" + + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/httpmock" + "github.com/stretchr/testify/assert" +) + +func TestUpdateProjectV2Items(t *testing.T) { + var tests = []struct { + name string + httpStubs func(*httpmock.Registry) + expectError bool + }{ + { + name: "updates project items", + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation UpdateProjectV2Items\b`), + httpmock.GraphQLQuery(`{"data":{"add_000":{"item":{"id":"1"}},"delete_001":{"item":{"id":"2"}}}}`, + func(mutations string, inputs map[string]interface{}) { + expectedMutations := ` + mutation UpdateProjectV2Items( + $input_000: AddProjectV2ItemByIdInput! + $input_001: AddProjectV2ItemByIdInput! + $input_002: DeleteProjectV2ItemInput! + $input_003: DeleteProjectV2ItemInput! + ) { + add_000: addProjectV2ItemById(input: $input_000) { item { id } } + add_001: addProjectV2ItemById(input: $input_001) { item { id } } + delete_002: deleteProjectV2Item(input: $input_002) { deletedItemId } + delete_003: deleteProjectV2Item(input: $input_003) { deletedItemId } + }` + assert.Equal(t, stripSpace(expectedMutations), stripSpace(mutations)) + if len(inputs) != 4 { + t.Fatalf("expected 4 inputs, got %d", len(inputs)) + } + i0 := inputs["input_000"].(map[string]interface{}) + i1 := inputs["input_001"].(map[string]interface{}) + i2 := inputs["input_002"].(map[string]interface{}) + i3 := inputs["input_003"].(map[string]interface{}) + adds := []string{ + fmt.Sprintf("%v -> %v", i0["contentId"], i0["projectId"]), + fmt.Sprintf("%v -> %v", i1["contentId"], i1["projectId"]), + } + removes := []string{ + fmt.Sprintf("%v x %v", i2["itemId"], i2["projectId"]), + fmt.Sprintf("%v x %v", i3["itemId"], i3["projectId"]), + } + sort.Strings(adds) + sort.Strings(removes) + assert.Equal(t, []string{"item1 -> project1", "item2 -> project2"}, adds) + assert.Equal(t, []string{"item3 x project3", "item4 x project4"}, removes) + })) + }, + }, + { + name: "fails to update project items", + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation UpdateProjectV2Items\b`), + httpmock.GraphQLMutation(`{"data":{}, "errors": [{"message": "some gql error"}]}`, func(inputs map[string]interface{}) {}), + ) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + if tt.httpStubs != nil { + tt.httpStubs(reg) + } + client := newTestClient(reg) + repo, _ := ghrepo.FromFullName("OWNER/REPO") + addProjectItems := map[string]string{"project1": "item1", "project2": "item2"} + deleteProjectItems := map[string]string{"project3": "item3", "project4": "item4"} + err := UpdateProjectV2Items(client, repo, addProjectItems, deleteProjectItems) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestProjectsV2ItemsForIssue(t *testing.T) { + var tests = []struct { + name string + httpStubs func(*httpmock.Registry) + expectItems ProjectItems + expectError bool + }{ + { + name: "retrieves project items for issue", + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query IssueProjectItems\b`), + httpmock.GraphQLQuery(`{"data":{"repository":{"issue":{"projectItems":{"nodes": [{"id":"projectItem1"},{"id":"projectItem2"}]}}}}}`, + func(query string, inputs map[string]interface{}) {}), + ) + }, + expectItems: ProjectItems{ + Nodes: []*ProjectV2Item{ + {ID: "projectItem1"}, + {ID: "projectItem2"}, + }, + }, + }, + { + name: "fails to retrieve project items for issue", + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query IssueProjectItems\b`), + httpmock.GraphQLQuery(`{"data":{}, "errors": [{"message": "some gql error"}]}`, + func(query string, inputs map[string]interface{}) {}), + ) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + if tt.httpStubs != nil { + tt.httpStubs(reg) + } + client := newTestClient(reg) + repo, _ := ghrepo.FromFullName("OWNER/REPO") + issue := &Issue{Number: 1} + err := ProjectsV2ItemsForIssue(client, repo, issue) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.expectItems, issue.ProjectItems) + }) + } +} + +func TestProjectsV2ItemsForPullRequest(t *testing.T) { + var tests = []struct { + name string + httpStubs func(*httpmock.Registry) + expectItems ProjectItems + expectError bool + }{ + { + name: "retrieves project items for pull request", + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query PullRequestProjectItems\b`), + httpmock.GraphQLQuery(`{"data":{"repository":{"pullRequest":{"projectItems":{"nodes": [{"id":"projectItem3"},{"id":"projectItem4"}]}}}}}`, + func(query string, inputs map[string]interface{}) {}), + ) + }, + expectItems: ProjectItems{ + Nodes: []*ProjectV2Item{ + {ID: "projectItem3"}, + {ID: "projectItem4"}, + }, + }, + }, + { + name: "fails to retrieve project items for pull request", + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query PullRequestProjectItems\b`), + httpmock.GraphQLQuery(`{"data":{}, "errors": [{"message": "some gql error"}]}`, + func(query string, inputs map[string]interface{}) {}), + ) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + if tt.httpStubs != nil { + tt.httpStubs(reg) + } + client := newTestClient(reg) + repo, _ := ghrepo.FromFullName("OWNER/REPO") + pr := &PullRequest{Number: 1} + err := ProjectsV2ItemsForPullRequest(client, repo, pr) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.expectItems, pr.ProjectItems) + }) + } +} + +func TestProjectsV2IgnorableError(t *testing.T) { + var tests = []struct { + name string + errMsg string + expectOut bool + }{ + { + name: "read scope error", + errMsg: "field requires one of the following scopes: ['read:project']", + expectOut: true, + }, + { + name: "repository projectsV2 field error", + errMsg: "Field 'projectsV2' doesn't exist on type 'Repository'", + expectOut: true, + }, + { + name: "organization projectsV2 field error", + errMsg: "Field 'projectsV2' doesn't exist on type 'Organization'", + expectOut: true, + }, + { + name: "issue projectItems field error", + errMsg: "Field 'projectItems' doesn't exist on type 'Issue'", + expectOut: true, + }, + { + name: "pullRequest projectItems field error", + errMsg: "Field 'projectItems' doesn't exist on type 'PullRequest'", + expectOut: true, + }, + { + name: "other error", + errMsg: "some other graphql error message", + expectOut: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := errors.New(tt.errMsg) + out := ProjectsV2IgnorableError(err) + assert.Equal(t, tt.expectOut, out) + }) + } +} + +func stripSpace(str string) string { + var b strings.Builder + b.Grow(len(str)) + for _, ch := range str { + if !unicode.IsSpace(ch) { + b.WriteRune(ch) + } + } + return b.String() +} diff --git a/api/queries_repo.go b/api/queries_repo.go index c5a0f5094..702cf7075 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -12,6 +13,7 @@ import ( "time" "github.com/cli/cli/v2/internal/ghinstance" + "golang.org/x/sync/errgroup" "github.com/cli/cli/v2/internal/ghrepo" ghAPI "github.com/cli/go-gh/pkg/api" @@ -43,6 +45,7 @@ type Repository struct { IsSecurityPolicyEnabled bool HasIssuesEnabled bool HasProjectsEnabled bool + HasDiscussionsEnabled bool HasWikiEnabled bool MergeCommitAllowed bool SquashMergeAllowed bool @@ -501,7 +504,7 @@ type repositoryV3 struct { } // ForkRepo forks the repository on GitHub and returns the new repository -func ForkRepo(client *Client, repo ghrepo.Interface, org, newName string) (*Repository, error) { +func ForkRepo(client *Client, repo ghrepo.Interface, org, newName string, defaultBranchOnly bool) (*Repository, error) { path := fmt.Sprintf("repos/%s/forks", ghrepo.FullName(repo)) params := map[string]interface{}{} @@ -511,6 +514,9 @@ func ForkRepo(client *Client, repo ghrepo.Interface, org, newName string) (*Repo if newName != "" { params["name"] = newName } + if defaultBranchOnly { + params["default_branch_only"] = true + } body := &bytes.Buffer{} enc := json.NewEncoder(body) @@ -647,6 +653,7 @@ type RepoMetadataResult struct { AssignableUsers []RepoAssignee Labels []RepoLabel Projects []RepoProject + ProjectsV2 []RepoProjectV2 Milestones []RepoMilestone Teams []OrgTeam } @@ -706,25 +713,52 @@ func (m *RepoMetadataResult) LabelsToIDs(names []string) ([]string, error) { return ids, nil } -func (m *RepoMetadataResult) ProjectsToIDs(names []string) ([]string, error) { +// ProjectsToIDs returns two arrays: +// - the first contains IDs of projects V1 +// - the second contains IDs of projects V2 +// - if neither project V1 or project V2 can be found with a given name, then an error is returned +func (m *RepoMetadataResult) ProjectsToIDs(names []string) ([]string, []string, error) { var ids []string + var idsV2 []string for _, projectName := range names { - found := false - for _, p := range m.Projects { - if strings.EqualFold(projectName, p.Name) { - ids = append(ids, p.ID) - found = true - break - } + id, found := m.projectNameToID(projectName) + if found { + ids = append(ids, id) + continue } - if !found { - return nil, fmt.Errorf("'%s' not found", projectName) + + idV2, found := m.projectV2TitleToID(projectName) + if found { + idsV2 = append(idsV2, idV2) + continue } + + return nil, nil, fmt.Errorf("'%s' not found", projectName) } - return ids, nil + return ids, idsV2, nil } -func ProjectsToPaths(projects []RepoProject, names []string) ([]string, error) { +func (m *RepoMetadataResult) projectNameToID(projectName string) (string, bool) { + for _, p := range m.Projects { + if strings.EqualFold(projectName, p.Name) { + return p.ID, true + } + } + + return "", false +} + +func (m *RepoMetadataResult) projectV2TitleToID(projectTitle string) (string, bool) { + for _, p := range m.ProjectsV2 { + if strings.EqualFold(projectTitle, p.Title) { + return p.ID, true + } + } + + return "", false +} + +func ProjectsToPaths(projects []RepoProject, projectsV2 []RepoProjectV2, names []string) ([]string, error) { var paths []string for _, projectName := range names { found := false @@ -744,6 +778,25 @@ func ProjectsToPaths(projects []RepoProject, names []string) ([]string, error) { break } } + if found { + continue + } + for _, p := range projectsV2 { + if strings.EqualFold(projectName, p.Title) { + // format of ResourcePath: /OWNER/REPO/projects/PROJECT_NUMBER or /orgs/ORG/projects/PROJECT_NUMBER + // required format of path: OWNER/REPO/PROJECT_NUMBER or ORG/PROJECT_NUMBER + var path string + pathParts := strings.Split(p.ResourcePath, "/") + if pathParts[1] == "orgs" { + path = fmt.Sprintf("%s/%s", pathParts[2], pathParts[4]) + } else { + path = fmt.Sprintf("%s/%s/%s", pathParts[1], pathParts[2], pathParts[4]) + } + paths = append(paths, path) + found = true + break + } + } if !found { return nil, fmt.Errorf("'%s' not found", projectName) } @@ -854,6 +907,18 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput errc <- nil }() } + if input.Projects { + count++ + go func() { + projectsV2, err := RepoAndOrgProjectsV2(client, repo) + if err != nil { + errc <- err + return + } + result.ProjectsV2 = projectsV2 + errc <- nil + }() + } if input.Milestones { count++ go func() { @@ -985,7 +1050,15 @@ type RepoProject struct { ResourcePath string `json:"resourcePath"` } -// RepoProjects fetches all open projects for a repository +type RepoProjectV2 struct { + ID string `json:"id"` + Title string `json:"title"` + Number int `json:"number"` + ResourcePath string `json:"resourcePath"` + Closed bool `json:"closed"` +} + +// RepoProjects fetches all open projects for a repository. func RepoProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) { type responseData struct { Repository struct { @@ -1023,23 +1096,87 @@ func RepoProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) return projects, nil } -// RepoAndOrgProjects fetches all open projects for a repository and its org +// RepoProjectsV2 fetches all open projectsV2 for a repository. +func RepoProjectsV2(client *Client, repo ghrepo.Interface) ([]RepoProjectV2, error) { + type responseData struct { + Repository struct { + ProjectsV2 struct { + Nodes []RepoProjectV2 + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"projectsV2(first: 100, orderBy: {field: TITLE, direction: ASC}, after: $endCursor, query: $query)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "name": githubv4.String(repo.RepoName()), + "endCursor": (*githubv4.String)(nil), + "query": githubv4.String("is:open"), + } + + var projectsV2 []RepoProjectV2 + for { + var query responseData + err := client.Query(repo.RepoHost(), "RepositoryProjectV2List", &query, variables) + if err != nil { + return nil, err + } + + projectsV2 = append(projectsV2, query.Repository.ProjectsV2.Nodes...) + + if !query.Repository.ProjectsV2.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Repository.ProjectsV2.PageInfo.EndCursor) + } + + return projectsV2, nil +} + +// RepoAndOrgProjects fetches all open projects for a repository and its organization. func RepoAndOrgProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) { projects, err := RepoProjects(client, repo) if err != nil { - return projects, fmt.Errorf("error fetching projects: %w", err) + return nil, fmt.Errorf("error fetching projects: %w", err) } orgProjects, err := OrganizationProjects(client, repo) - // TODO: better detection of non-org repos + // TODO: Better detection of non-org repos. if err != nil && !strings.Contains(err.Error(), "Could not resolve to an Organization") { - return projects, fmt.Errorf("error fetching organization projects: %w", err) + return nil, fmt.Errorf("error fetching organization projects: %w", err) } + projects = append(projects, orgProjects...) return projects, nil } +// RepoAndOrgProjectsV2 fetches all open projectsV2 for a repository and its organization. +// Note: If the auth token does not have sufficient scopes or projectsV2 is not supported +// on the host then those errors are swallowed and nil is returned. +func RepoAndOrgProjectsV2(client *Client, repo ghrepo.Interface) ([]RepoProjectV2, error) { + projectsV2, err := RepoProjectsV2(client, repo) + if err != nil { + if ProjectsV2IgnorableError(err) { + return nil, nil + } + + return nil, fmt.Errorf("error fetching projectsV2: %w", err) + } + + orgProjectsV2, err := OrganizationProjectsV2(client, repo) + if err != nil && !strings.Contains(err.Error(), "Could not resolve to an Organization") { + return nil, fmt.Errorf("error fetching organization projectsV2: %w", err) + } + + projectsV2 = append(projectsV2, orgProjectsV2...) + + return projectsV2, nil +} + type RepoAssignee struct { ID string Login string @@ -1192,12 +1329,27 @@ func RepoMilestones(client *Client, repo ghrepo.Interface, state string) ([]Repo } func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []string) ([]string, error) { - var paths []string - projects, err := RepoAndOrgProjects(client, repo) - if err != nil { - return paths, err + g, _ := errgroup.WithContext(context.Background()) + var projects []RepoProject + var projectsV2 []RepoProjectV2 + + g.Go(func() error { + var err error + projects, err = RepoAndOrgProjects(client, repo) + return err + }) + + g.Go(func() error { + var err error + projectsV2, err = RepoAndOrgProjectsV2(client, repo) + return err + }) + + if err := g.Wait(); err != nil { + return nil, err } - return ProjectsToPaths(projects, projectNames) + + return ProjectsToPaths(projects, projectsV2, projectNames) } func CreateRepoTransformToV4(apiClient *Client, hostname string, method string, path string, body io.Reader) (*Repository, error) { diff --git a/api/queries_repo_test.go b/api/queries_repo_test.go index da306de51..15276cc50 100644 --- a/api/queries_repo_test.go +++ b/api/queries_repo_test.go @@ -89,6 +89,17 @@ func Test_RepoMetadata(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) http.Register( httpmock.GraphQL(`query OrganizationProjectList\b`), httpmock.StringResponse(` @@ -99,6 +110,16 @@ func Test_RepoMetadata(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) http.Register( httpmock.GraphQL(`query OrganizationTeamList\b`), httpmock.StringResponse(` @@ -149,13 +170,17 @@ func Test_RepoMetadata(t *testing.T) { } expectedProjectIDs := []string{"TRIAGEID", "ROADMAPID"} - projectIDs, err := result.ProjectsToIDs([]string{"triage", "roadmap"}) + expectedProjectV2IDs := []string{"TRIAGEV2ID", "ROADMAPV2ID"} + projectIDs, projectV2IDs, err := result.ProjectsToIDs([]string{"triage", "roadmap", "triagev2", "roadmapv2"}) if err != nil { t.Errorf("error resolving projects: %v", err) } if !sliceEqual(projectIDs, expectedProjectIDs) { t.Errorf("expected projects %v, got %v", expectedProjectIDs, projectIDs) } + if !sliceEqual(projectV2IDs, expectedProjectV2IDs) { + t.Errorf("expected projectsV2 %v, got %v", expectedProjectV2IDs, projectV2IDs) + } expectedMilestoneID := "BIGONEID" milestoneID, err := result.MilestoneToID("big one.oh") @@ -173,15 +198,19 @@ func Test_RepoMetadata(t *testing.T) { } func Test_ProjectsToPaths(t *testing.T) { - expectedProjectPaths := []string{"OWNER/REPO/PROJECT_NUMBER", "ORG/PROJECT_NUMBER"} + expectedProjectPaths := []string{"OWNER/REPO/PROJECT_NUMBER", "ORG/PROJECT_NUMBER", "OWNER/REPO/PROJECT_NUMBER_2"} projects := []RepoProject{ {ID: "id1", Name: "My Project", ResourcePath: "/OWNER/REPO/projects/PROJECT_NUMBER"}, {ID: "id2", Name: "Org Project", ResourcePath: "/orgs/ORG/projects/PROJECT_NUMBER"}, {ID: "id3", Name: "Project", ResourcePath: "/orgs/ORG/projects/PROJECT_NUMBER_2"}, } - projectNames := []string{"My Project", "Org Project"} + projectsV2 := []RepoProjectV2{ + {ID: "id4", Title: "My Project V2", ResourcePath: "/OWNER/REPO/projects/PROJECT_NUMBER_2"}, + {ID: "id5", Title: "Org Project V2", ResourcePath: "/orgs/ORG/projects/PROJECT_NUMBER_3"}, + } + projectNames := []string{"My Project", "Org Project", "My Project V2"} - projectPaths, err := ProjectsToPaths(projects, projectNames) + projectPaths, err := ProjectsToPaths(projects, projectsV2, projectNames) if err != nil { t.Errorf("error resolving projects: %v", err) } @@ -210,20 +239,41 @@ func Test_ProjectNamesToPaths(t *testing.T) { http.Register( httpmock.GraphQL(`query OrganizationProjectList\b`), httpmock.StringResponse(` - { "data": { "organization": { "projects": { - "nodes": [ - { "name": "Triage", "id": "TRIAGEID", "resourcePath": "/orgs/ORG/projects/1" } - ], - "pageInfo": { "hasNextPage": false } - } } } } - `)) + { "data": { "organization": { "projects": { + "nodes": [ + { "name": "Triage", "id": "TRIAGEID", "resourcePath": "/orgs/ORG/projects/1" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID", "resourcePath": "/OWNER/REPO/projects/3" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID", "resourcePath": "/OWNER/REPO/projects/4" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID", "resourcePath": "/orgs/ORG/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) - projectPaths, err := ProjectNamesToPaths(client, repo, []string{"Triage", "Roadmap"}) + projectPaths, err := ProjectNamesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2"}) if err != nil { t.Fatalf("unexpected error: %v", err) } - expectedProjectPaths := []string{"ORG/1", "OWNER/REPO/2"} + expectedProjectPaths := []string{"ORG/1", "OWNER/REPO/2", "ORG/2", "OWNER/REPO/4"} if !sliceEqual(projectPaths, expectedProjectPaths) { t.Errorf("expected projects paths %v, got %v", expectedProjectPaths, projectPaths) } diff --git a/api/query_builder.go b/api/query_builder.go index ce7b1791b..65ef6d3f7 100644 --- a/api/query_builder.go +++ b/api/query_builder.go @@ -263,7 +263,7 @@ func IssueGraphQL(fields []string) string { case "author": q = append(q, `author{login,...on User{id,name}}`) case "mergedBy": - q = append(q, `mergedBy{login}`) + q = append(q, `mergedBy{login,...on User{id,name}}`) case "headRepositoryOwner": q = append(q, `headRepositoryOwner{id,login,...on User{name}}`) case "headRepository": @@ -274,6 +274,8 @@ func IssueGraphQL(fields []string) string { q = append(q, `labels(first:100){nodes{id,name,description,color},totalCount}`) case "projectCards": q = append(q, `projectCards(first:100){nodes{project{name}column{name}},totalCount}`) + case "projectItems": + q = append(q, `projectItems(first:100){nodes{id, project{id,title}},totalCount}`) case "milestone": q = append(q, `milestone{number,title,description,dueOn}`) case "reactionGroups": @@ -346,6 +348,7 @@ var RepositoryFields = []string{ "hasIssuesEnabled", "hasProjectsEnabled", "hasWikiEnabled", + "hasDiscussionsEnabled", "mergeCommitAllowed", "squashMergeAllowed", "rebaseMergeAllowed", diff --git a/api/sanitize_ascii.go b/api/sanitize_ascii.go new file mode 100644 index 000000000..6033a07a6 --- /dev/null +++ b/api/sanitize_ascii.go @@ -0,0 +1,195 @@ +package api + +import ( + "bytes" + "errors" + "io" + "net/http" + "regexp" + "strings" +) + +var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + +// GitHub servers return non-printable characters as their unicode code point values. +// The values of \u0000 to \u001F represent C0 ASCII control characters and +// the values of \u0080 to \u009F represent C1 ASCII control characters. These control +// characters will be interpreted by the terminal, this behaviour can be used maliciously +// as an attack vector, especially the control character \u001B. This function wraps +// JSON response bodies in a ReadCloser that transforms C0 and C1 control characters +// to their caret and hex notations respectively so that the terminal will not interpret them. +func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + res, err := rt.RoundTrip(req) + if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) { + return res, err + } + res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body} + return res, err + }} +} + +// sanitizeASCIIReadCloser implements the ReadCloser interface. +type sanitizeASCIIReadCloser struct { + io.ReadCloser + addEscape bool + remainder []byte +} + +// Read uses a sliding window alogorithm to detect C0 and C1 +// ASCII control sequences as they are read and replaces them +// with equivelent inert characters. Characters that are not part +// of a control sequence not modified. +func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) { + var bufIndex, outIndex int + outLen := len(out) + buf := make([]byte, outLen) + + bufLen, readErr := s.ReadCloser.Read(buf) + if readErr != nil && !errors.Is(readErr, io.EOF) { + if bufLen > 0 { + // Do not sanitize if there was a read error that is not EOF. + bufLen = copy(out, buf) + } + return bufLen, readErr + } + buf = buf[:bufLen] + + if s.remainder != nil { + buf = append(s.remainder, buf...) + bufLen += len(s.remainder) + s.remainder = s.remainder[:0] + } + + for bufIndex < bufLen-6 && outIndex < outLen { + window := buf[bufIndex : bufIndex+6] + + if bytes.HasPrefix(window, []byte(`\u00`)) { + repl, _ := mapControlCharacterToCaret(window) + if s.addEscape { + repl = append([]byte{'\\'}, repl...) + s.addEscape = false + } + for j := 0; j < len(repl); j++ { + if outIndex < outLen { + out[outIndex] = repl[j] + outIndex++ + } else { + s.remainder = append(s.remainder, repl[j]) + } + } + bufIndex += 6 + continue + } + + if window[0] == '\\' { + s.addEscape = !s.addEscape + } else { + s.addEscape = false + } + + out[outIndex] = buf[bufIndex] + outIndex++ + bufIndex++ + } + + if readErr != nil && errors.Is(readErr, io.EOF) { + remaining := bufLen - bufIndex + for j := 0; j < remaining; j++ { + if outIndex < outLen { + out[outIndex] = buf[bufIndex] + outIndex++ + bufIndex++ + } else { + s.remainder = append(s.remainder, buf[bufIndex]) + bufIndex++ + } + } + } else { + if bufIndex < bufLen { + s.remainder = append(s.remainder, buf[bufIndex:]...) + } + } + + if len(s.remainder) != 0 { + readErr = nil + } + + return outIndex, readErr +} + +// mapControlCharacterToCaret maps C0 control sequences to caret notation +// and C1 control sequences to hex notation. C1 control sequences do not +// have caret notation representation. +func mapControlCharacterToCaret(b []byte) ([]byte, bool) { + m := map[string]string{ + `\u0000`: `^@`, + `\u0001`: `^A`, + `\u0002`: `^B`, + `\u0003`: `^C`, + `\u0004`: `^D`, + `\u0005`: `^E`, + `\u0006`: `^F`, + `\u0007`: `^G`, + `\u0008`: `^H`, + `\u0009`: `^I`, + `\u000a`: `^J`, + `\u000b`: `^K`, + `\u000c`: `^L`, + `\u000d`: `^M`, + `\u000e`: `^N`, + `\u000f`: `^O`, + `\u0010`: `^P`, + `\u0011`: `^Q`, + `\u0012`: `^R`, + `\u0013`: `^S`, + `\u0014`: `^T`, + `\u0015`: `^U`, + `\u0016`: `^V`, + `\u0017`: `^W`, + `\u0018`: `^X`, + `\u0019`: `^Y`, + `\u001a`: `^Z`, + `\u001b`: `^[`, + `\u001c`: `^\\`, + `\u001d`: `^]`, + `\u001e`: `^^`, + `\u001f`: `^_`, + `\u0080`: `\\200`, + `\u0081`: `\\201`, + `\u0082`: `\\202`, + `\u0083`: `\\203`, + `\u0084`: `\\204`, + `\u0085`: `\\205`, + `\u0086`: `\\206`, + `\u0087`: `\\207`, + `\u0088`: `\\210`, + `\u0089`: `\\211`, + `\u008a`: `\\212`, + `\u008b`: `\\213`, + `\u008c`: `\\214`, + `\u008d`: `\\215`, + `\u008e`: `\\216`, + `\u008f`: `\\217`, + `\u0090`: `\\220`, + `\u0091`: `\\221`, + `\u0092`: `\\222`, + `\u0093`: `\\223`, + `\u0094`: `\\224`, + `\u0095`: `\\225`, + `\u0096`: `\\226`, + `\u0097`: `\\227`, + `\u0098`: `\\230`, + `\u0099`: `\\231`, + `\u009a`: `\\232`, + `\u009b`: `\\233`, + `\u009c`: `\\234`, + `\u009d`: `\\235`, + `\u009e`: `\\236`, + `\u009f`: `\\237`, + } + if c, ok := m[strings.ToLower(string(b))]; ok { + return []byte(c), true + } + return b, false +} diff --git a/api/sanitize_ascii_test.go b/api/sanitize_ascii_test.go new file mode 100644 index 000000000..ff43f9287 --- /dev/null +++ b/api/sanitize_ascii_test.go @@ -0,0 +1,62 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "testing/iotest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPClient_SanitizeASCIIControlCharacters(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + issue := Issue{ + Title: "\u001B[31mRed Title\u001B[0m", + Body: "1\u0001 2\u0002 3\u0003 4\u0004 5\u0005 6\u0006 7\u0007 8\u0008 9\t A\r\n B\u000b C\u000c D\r\n E\u000e F\u000f", + Author: Author{ + ID: "1", + Name: "10\u0010 11\u0011 12\u0012 13\u0013 14\u0014 15\u0015 16\u0016 17\u0017 18\u0018 19\u0019 1A\u001a 1B\u001b 1C\u001c 1D\u001d 1E\u001e 1F\u001f", + Login: "monalisa", + }, + ActiveLockReason: "Escaped \u001B \\u001B \\\u001B \\\\u001B", + } + responseData, _ := json.Marshal(issue) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + fmt.Fprint(w, string(responseData)) + })) + defer ts.Close() + + client, err := NewHTTPClient(HTTPClientOptions{}) + require.NoError(t, err) + req, err := http.NewRequest("GET", ts.URL, nil) + require.NoError(t, err) + res, err := client.Do(req) + require.NoError(t, err) + body, err := io.ReadAll(res.Body) + res.Body.Close() + require.NoError(t, err) + var issue Issue + err = json.Unmarshal(body, &issue) + require.NoError(t, err) + assert.Equal(t, "^[[31mRed Title^[[0m", issue.Title) + assert.Equal(t, "1^A 2^B 3^C 4^D 5^E 6^F 7^G 8^H 9\t A\r\n B^K C^L D\r\n E^N F^O", issue.Body) + assert.Equal(t, "10^P 11^Q 12^R 13^S 14^T 15^U 16^V 17^W 18^X 19^Y 1A^Z 1B^[ 1C^\\ 1D^] 1E^^ 1F^_", issue.Author.Name) + assert.Equal(t, "monalisa", issue.Author.Login) + assert.Equal(t, "Escaped ^[ \\^[ \\^[ \\\\^[", issue.ActiveLockReason) +} + +func TestSanitizeASCIIReadCloser(t *testing.T) { + data := []byte(`"Assign},"L`) + var r io.Reader = bytes.NewReader(data) + r = &sanitizeASCIIReadCloser{ReadCloser: io.NopCloser(r)} + r = iotest.OneByteReader(r) + out, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, data, out) +} diff --git a/cmd/gh/main.go b/cmd/gh/main.go index 077e91995..eb0ad0672 100644 --- a/cmd/gh/main.go +++ b/cmd/gh/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "io" @@ -53,17 +54,24 @@ func main() { func mainRun() exitCode { buildDate := build.Date buildVersion := build.Version - - updateMessageChan := make(chan *update.ReleaseInfo) - go func() { - rel, _ := checkForUpdate(buildVersion) - updateMessageChan <- rel - }() - hasDebug, _ := utils.IsDebugEnabled() cmdFactory := factory.New(buildVersion) stderr := cmdFactory.IOStreams.ErrOut + + ctx := context.Background() + + updateCtx, updateCancel := context.WithCancel(ctx) + defer updateCancel() + updateMessageChan := make(chan *update.ReleaseInfo) + go func() { + rel, err := checkForUpdate(updateCtx, cmdFactory, buildVersion) + if err != nil && hasDebug { + fmt.Fprintf(stderr, "warning: checking for update failed: %v", err) + } + updateMessageChan <- rel + }() + if !cmdFactory.IOStreams.ColorEnabled() { surveyCore.DisableColor = true ansi.DisableColors(true) @@ -209,7 +217,7 @@ func mainRun() exitCode { rootCmd.SetArgs(expandedArgs) - if cmd, err := rootCmd.ExecuteC(); err != nil { + if cmd, err := rootCmd.ExecuteContextC(ctx); err != nil { var pagerPipeError *iostreams.ErrClosedPagerPipe var noResultsError cmdutil.NoResultsError if err == cmdutil.SilentError { @@ -257,6 +265,7 @@ func mainRun() exitCode { return exitError } + updateCancel() // if the update checker hasn't completed by now, abort it newRelease := <-updateMessageChan if newRelease != nil { isHomebrew := isUnderHomebrew(cmdFactory.Executable()) @@ -348,21 +357,17 @@ func isCI() bool { os.Getenv("RUN_ID") != "" // TaskCluster, dsari } -func checkForUpdate(currentVersion string) (*update.ReleaseInfo, error) { +func checkForUpdate(ctx context.Context, f *cmdutil.Factory, currentVersion string) (*update.ReleaseInfo, error) { if !shouldCheckForUpdate() { return nil, nil } - httpClient, err := api.NewHTTPClient(api.HTTPClientOptions{ - AppVersion: currentVersion, - Log: os.Stderr, - }) + httpClient, err := f.HttpClient() if err != nil { return nil, err } - client := api.NewClientFromHTTP(httpClient) repo := updaterEnabled stateFilePath := filepath.Join(config.StateDir(), "state.yml") - return update.CheckForUpdate(client, stateFilePath, repo, currentVersion) + return update.CheckForUpdate(ctx, httpClient, stateFilePath, repo, currentVersion) } func isRecentRelease(publishedAt time.Time) bool { diff --git a/context/context.go b/context/context.go index dc6407a67..7030a557f 100644 --- a/context/context.go +++ b/context/context.go @@ -11,9 +11,9 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" ) -// cap the number of git remotes looked up, since the user might have an -// unusually large number of git remotes -const maxRemotesForLookup = 5 +// Cap the number of git remotes to look up, since the user might have an +// unusually large number of git remotes. +const defaultRemotesForLookup = 5 func ResolveRemotesToRepos(remotes Remotes, client *api.Client, base string) (*ResolvedRemotes, error) { sort.Stable(remotes) @@ -36,11 +36,11 @@ func ResolveRemotesToRepos(remotes Remotes, client *api.Client, base string) (*R return result, nil } -func resolveNetwork(result *ResolvedRemotes) error { +func resolveNetwork(result *ResolvedRemotes, remotesForLookup int) error { var repos []ghrepo.Interface for _, r := range result.remotes { repos = append(repos, r) - if len(repos) == maxRemotesForLookup { + if len(repos) == remotesForLookup { break } } @@ -84,7 +84,7 @@ func (r *ResolvedRemotes) BaseRepo(io *iostreams.IOStreams) (ghrepo.Interface, e return r.remotes[0], nil } - repos, err := r.NetworkRepos() + repos, err := r.NetworkRepos(defaultRemotesForLookup) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func (r *ResolvedRemotes) BaseRepo(io *iostreams.IOStreams) (ghrepo.Interface, e func (r *ResolvedRemotes) HeadRepos() ([]*api.Repository, error) { if r.network == nil { - err := resolveNetwork(r) + err := resolveNetwork(r, defaultRemotesForLookup) if err != nil { return nil, err } @@ -124,9 +124,11 @@ func (r *ResolvedRemotes) HeadRepos() ([]*api.Repository, error) { return results, nil } -func (r *ResolvedRemotes) NetworkRepos() ([]*api.Repository, error) { +// NetworkRepos fetches info about remotes for the network of repos. +// Pass a value of 0 to fetch info on all remotes. +func (r *ResolvedRemotes) NetworkRepos(remotesForLookup int) ([]*api.Repository, error) { if r.network == nil { - err := resolveNetwork(r) + err := resolveNetwork(r, remotesForLookup) if err != nil { return nil, err } diff --git a/docs/working-with-us.md b/docs/working-with-us.md new file mode 100644 index 000000000..19aa6e171 --- /dev/null +++ b/docs/working-with-us.md @@ -0,0 +1,55 @@ +# Working with the GitHub CLI Team: Hubber Edition + +POV: your team at GitHub is interested in shipping a new command in `gh`. + +This document outlines the process the CLI team prefers for helping ensure success both for your new feature and the CLI project as a whole. + +## Step 0: Create an extension + +Even if you want to see your code merged into `gh`, you should start with [an extension](https://docs.github.com/en/github-cli/github-cli/creating-github-cli-extensions) written in Go and leveraging [go-gh](https://github.com/cli/go-gh). Though `gh` extensions can be written in any language, we treat Go as a first class experience and ship a library of helpers for extensions written in Go. + +Creating an extension enables you to start prototyping immediately, without waiting for us, and gives us something tangible to review if you decide you'd like the work incorporated into `gh`. It also means that you can decide to simply release your work without waiting for us to merge it, which leaves you in charge of release scheduling moving forward. + +If you know from this point that you're comfortable with your new feature being an extension, don't worry about the rest of this document. We don't dictate how people create and release `gh` extensions. + +If you do want your feature merged into `gh`, read on. + +## Step 1: UX review + +No matter what state your code is in, open up an issue either in [the open source cli/cli repository](https://github.com/cli/cli) or, if you'd rather not make the new feature public yet, [the closed github/cli repository](https://github.com/github/cli). + +Describe how your new command would be used. Include mock-up examples, including a mock-up of what usage information would be printed if a user ran your command with `--help`. + +We take this step seriously because we believe in keeping `gh`'s interface consistent and intuitive. + +## Step 2: Beta + +Once we've signed off on the proposed UX on the issue opened in step 1, develop your extension to at least beta quality. It's up to you if you actually want to go through a beta release phase with real users or not. + +## Step 3: Merge or no merge + +With a beta in hand it's time to decide whether or not to mainline your extension into the `trunk` of `gh`. Some questions to consider: + +- How complex is the support burden for your feature? + +If this feature requires extensive or specialized support, you will either need to release it as an extension or work with the CLI team to get maintainer access to `cli/cli`. The CLI team is very small and cannot promise any kind of SLA for supporting your work. For example, the `gh cs` command is sufficiently specialized and complex that we have given the `codespaces` team write access to the repository to maintain their own pull request review process. We have not put it in an extension as Codespaces are a core GitHub product with widespread use among our users. + +- What kind of release cadence do you want? + +We do a `gh` release roughly every other week, but if the changeset for a given week is light we may skip one. We make no official promise as to our cadence, and while we do have an on-call rotation there is no guarantee that you'll be able to get emergency fixes out within hours. If this is troubling, consider keeping your work in an extension. + +- What kind of audience are you trying to reach? + +Is this new feature intended for all GitHub users or just a few? If it's as applicable to your average GitHub user or customer as something like Codespaces or Pull Requests, that's a strong indication it should be merged into `trunk`. If not, consider keeping it an extension. + +If after all of this consideration you think your feature should be merged, please open an issue in [cli/cli](https://github.com/cli/cli) with a link to your extension's code. It will go into our triage queue and we'll confirm that merging into `trunk` is feasible and appropriate. + +## Step 4 + +Once we've signed off, open up a pull request in [cli/cli](https://github.com/cli/cli) adding your command. Since we make use of `go-gh` within our code already, it shouldn't be too onerous to make your extension merge-able. Link to the issue you opened in step 3 so we have some context on the pull request. + +## Other considerations + +- If you have a high need for secrecy until the point of release, let us know in [#cli on slack](https://github.slack.com/archives/CLLG3RMAR). We'll come up with a solution to work on merging your command in private. +- We are a highly asynchronous team due to wide timezone differences. The best way to get in touch with us is via issue and pull request comments to which we'll respond within 24 hours. You can ping us on Slack but that's generally not our preference. +- We are happy to pair with you on extension authoring! Just let us know if we can provide guidance and we can schedule synchronous time to work together with you. diff --git a/git/client.go b/git/client.go index c9533094f..b45ce9138 100644 --- a/git/client.go +++ b/git/client.go @@ -20,6 +20,10 @@ import ( var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`) +type errWithExitCode interface { + ExitCode() int +} + type Client struct { GhPath string RepoDir string @@ -390,7 +394,10 @@ func (c *Client) revParse(ctx context.Context, args ...string) ([]byte, error) { // Below are commands that make network calls and need authentication credentials supplied from gh. func (c *Client) Fetch(ctx context.Context, remote string, refspec string, mods ...CommandModifier) error { - args := []string{"fetch", remote, refspec} + args := []string{"fetch", remote} + if refspec != "" { + args = append(args, refspec) + } cmd, err := c.AuthenticatedCommand(ctx, args...) if err != nil { return err @@ -458,8 +465,8 @@ func (c *Client) AddRemote(ctx context.Context, name, urlStr string, trackingBra for _, branch := range trackingBranches { args = append(args, "-t", branch) } - args = append(args, "-f", name, urlStr) - cmd, err := c.AuthenticatedCommand(ctx, args...) + args = append(args, name, urlStr) + cmd, err := c.Command(ctx, args...) if err != nil { return nil, err } @@ -489,21 +496,16 @@ func (c *Client) AddRemote(ctx context.Context, name, urlStr string, trackingBra return remote, nil } -func (c *Client) InGitDirectory(ctx context.Context) bool { - showCmd, err := c.Command(ctx, "rev-parse", "--is-inside-work-tree") +func (c *Client) IsLocalGitRepo(ctx context.Context) (bool, error) { + _, err := c.GitDir(ctx) if err != nil { - return false + var execError errWithExitCode + if errors.As(err, &execError) && execError.ExitCode() == 128 { + return false, nil + } + return false, err } - out, err := showCmd.Output() - if err != nil { - return false - } - - split := strings.Split(string(out), "\n") - if len(split) > 0 { - return split[0] == "true" - } - return false + return true, nil } func (c *Client) UnsetRemoteResolution(ctx context.Context, name string) error { diff --git a/git/client_test.go b/git/client_test.go index e12789c8c..68cfc2739 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -1089,7 +1089,7 @@ func TestClientAddRemote(t *testing.T) { url: "URL", dir: "DIRECTORY", branches: []string{}, - wantCmdArgs: `path/to/git -C DIRECTORY -c credential.helper= -c credential.helper=!"gh" auth git-credential remote add -f test URL`, + wantCmdArgs: `path/to/git -C DIRECTORY remote add test URL`, }, { title: "fetch specific branches only", @@ -1097,7 +1097,7 @@ func TestClientAddRemote(t *testing.T) { url: "URL", dir: "DIRECTORY", branches: []string{"trunk", "dev"}, - wantCmdArgs: `path/to/git -C DIRECTORY -c credential.helper= -c credential.helper=!"gh" auth git-credential remote add -t trunk -t dev -f test URL`, + wantCmdArgs: `path/to/git -C DIRECTORY remote add -t trunk -t dev test URL`, }, } for _, tt := range tests { diff --git a/go.mod b/go.mod index 2b11dc420..fcba58a19 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da github.com/charmbracelet/lipgloss v0.5.0 github.com/cli/go-gh v1.0.0 - github.com/cli/oauth v0.9.0 + github.com/cli/oauth v1.0.1 github.com/cli/safeexec v1.0.1 github.com/cpuguy83/go-md2man/v2 v2.0.2 github.com/creack/pty v1.1.18 @@ -22,7 +22,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-version v1.3.0 github.com/henvic/httpretty v0.0.6 - github.com/joho/godotenv v1.4.0 + github.com/joho/godotenv v1.5.1 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.17 @@ -81,3 +81,5 @@ require ( ) replace golang.org/x/crypto => github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 + +replace github.com/henvic/httpretty v0.0.6 => github.com/mislav/httpretty v0.1.1-0.20230202151216-d31343e0d884 diff --git a/go.sum b/go.sum index a17425bfb..03be7b86e 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,8 @@ github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= github.com/cli/go-gh v1.0.0 h1:zE1YUAUYqGXNZuICEBeOkIMJ5F50BS0ftvtoWGlsEFI= github.com/cli/go-gh v1.0.0/go.mod h1:bqxLdCoTZ73BuiPEJx4olcO/XKhVZaFDchFagYRBweE= -github.com/cli/oauth v0.9.0 h1:nxBC0Df4tUzMkqffAB+uZvisOwT3/N9FpkfdTDtafxc= -github.com/cli/oauth v0.9.0/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= +github.com/cli/oauth v1.0.1 h1:pXnTFl/qUegXHK531Dv0LNjW4mLx626eS42gnzfXJPA= +github.com/cli/oauth v1.0.1/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= @@ -164,8 +164,6 @@ github.com/hashicorp/go-version v1.3.0 h1:McDWVJIU/y+u1BRV06dPaLfLCaT7fUTJLp5r04 github.com/hashicorp/go-version v1.3.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/henvic/httpretty v0.0.6 h1:JdzGzKZBajBfnvlMALXXMVQWxWMF/ofTy8C3/OSUTxs= -github.com/henvic/httpretty v0.0.6/go.mod h1:X38wLjWXHkXT7r2+uK8LjCMne9rsuNaBLJ+5cU2/Pmo= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= @@ -175,8 +173,8 @@ github.com/itchyny/gojq v0.12.8 h1:Zxcwq8w4IeR8JJYEtoG2MWJZUv0RGY6QqJcO1cqV8+A= github.com/itchyny/gojq v0.12.8/go.mod h1:gE2kZ9fVRU0+JAksaTzjIlgnCa2akU+a1V0WXgJQN5c= github.com/itchyny/timefmt-go v0.1.3 h1:7M3LGVDsqcd0VZH2U+x393obrzZisp7C0uEe921iRkU= github.com/itchyny/timefmt-go v0.1.3/go.mod h1:0osSSCQSASBJMsIZnhAaF1C2fCBTJZXrnj37mG8/c+A= -github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= -github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= @@ -209,6 +207,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyex github.com/microcosm-cc/bluemonday v1.0.19/go.mod h1:QNzV2UbLK2/53oIIwTOyLUSABMkjZ4tqiyC1g/DyqxE= github.com/microcosm-cc/bluemonday v1.0.20 h1:flpzsq4KU3QIYAYGV/szUat7H+GPOXR0B2JU5A1Wp8Y= github.com/microcosm-cc/bluemonday v1.0.20/go.mod h1:yfBmMi8mxvaZut3Yytv+jTXRY8mxyjJ0/kQBTElld50= +github.com/mislav/httpretty v0.1.1-0.20230202151216-d31343e0d884 h1:JQp1j1IWuMQZc2tyDQ9KmksjQbw5MhUOzWzZZn7WyU0= +github.com/mislav/httpretty v0.1.1-0.20230202151216-d31343e0d884/go.mod h1:ViEsly7wgdugYtymX54pYp6Vv2wqZmNHayJ6q8tlKCc= github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= @@ -259,6 +259,7 @@ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.4/go.mod h1:rmuwmfZ0+bvzB24eSC//bk1R1Zp3hM0OXYv/G2LIilg= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -299,6 +300,7 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -327,6 +329,7 @@ golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -347,6 +350,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -382,6 +386,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -446,6 +451,8 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 2dc81ba64..5c85ae616 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "time" "github.com/cenkalti/backoff/v4" @@ -70,7 +71,6 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session defer progress.StopProgressIndicator() return liveshare.Connect(ctx, liveshare.Options{ - ClientName: "gh", SessionID: codespace.Connection.SessionID, SessionToken: codespace.Connection.SessionToken, RelaySAS: codespace.Connection.RelaySAS, @@ -79,3 +79,18 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session Logger: sessionLogger, }) } + +// ListenTCP starts a localhost tcp listener and returns the listener and bound port +func ListenTCP(port int) (*net.TCPListener, int, error) { + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return nil, 0, fmt.Errorf("failed to build tcp address: %w", err) + } + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, 0, fmt.Errorf("failed to listen to local port over tcp: %w", err) + } + port = listener.Addr().(*net.TCPAddr).Port + + return listener, port, nil +} diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go index 1620de348..6da7f9e39 100644 --- a/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: codespace/codespace_host_service.v1.proto @@ -20,6 +20,116 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type NotifyCodespaceOfClientActivityRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=ClientId,proto3" json:"ClientId,omitempty"` + ClientActivities []string `protobuf:"bytes,2,rep,name=ClientActivities,proto3" json:"ClientActivities,omitempty"` +} + +func (x *NotifyCodespaceOfClientActivityRequest) Reset() { + *x = NotifyCodespaceOfClientActivityRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *NotifyCodespaceOfClientActivityRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NotifyCodespaceOfClientActivityRequest) ProtoMessage() {} + +func (x *NotifyCodespaceOfClientActivityRequest) ProtoReflect() protoreflect.Message { + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NotifyCodespaceOfClientActivityRequest.ProtoReflect.Descriptor instead. +func (*NotifyCodespaceOfClientActivityRequest) Descriptor() ([]byte, []int) { + return file_codespace_codespace_host_service_v1_proto_rawDescGZIP(), []int{0} +} + +func (x *NotifyCodespaceOfClientActivityRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *NotifyCodespaceOfClientActivityRequest) GetClientActivities() []string { + if x != nil { + return x.ClientActivities + } + return nil +} + +type NotifyCodespaceOfClientActivityResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Result bool `protobuf:"varint,1,opt,name=Result,proto3" json:"Result,omitempty"` + Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"` +} + +func (x *NotifyCodespaceOfClientActivityResponse) Reset() { + *x = NotifyCodespaceOfClientActivityResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *NotifyCodespaceOfClientActivityResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NotifyCodespaceOfClientActivityResponse) ProtoMessage() {} + +func (x *NotifyCodespaceOfClientActivityResponse) ProtoReflect() protoreflect.Message { + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NotifyCodespaceOfClientActivityResponse.ProtoReflect.Descriptor instead. +func (*NotifyCodespaceOfClientActivityResponse) Descriptor() ([]byte, []int) { + return file_codespace_codespace_host_service_v1_proto_rawDescGZIP(), []int{1} +} + +func (x *NotifyCodespaceOfClientActivityResponse) GetResult() bool { + if x != nil { + return x.Result + } + return false +} + +func (x *NotifyCodespaceOfClientActivityResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + type RebuildContainerRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -31,7 +141,7 @@ type RebuildContainerRequest struct { func (x *RebuildContainerRequest) Reset() { *x = RebuildContainerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[0] + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -44,7 +154,7 @@ func (x *RebuildContainerRequest) String() string { func (*RebuildContainerRequest) ProtoMessage() {} func (x *RebuildContainerRequest) ProtoReflect() protoreflect.Message { - mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[0] + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -57,7 +167,7 @@ func (x *RebuildContainerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RebuildContainerRequest.ProtoReflect.Descriptor instead. func (*RebuildContainerRequest) Descriptor() ([]byte, []int) { - return file_codespace_codespace_host_service_v1_proto_rawDescGZIP(), []int{0} + return file_codespace_codespace_host_service_v1_proto_rawDescGZIP(), []int{2} } func (x *RebuildContainerRequest) GetIncremental() bool { @@ -78,7 +188,7 @@ type RebuildContainerResponse struct { func (x *RebuildContainerResponse) Reset() { *x = RebuildContainerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[1] + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -91,7 +201,7 @@ func (x *RebuildContainerResponse) String() string { func (*RebuildContainerResponse) ProtoMessage() {} func (x *RebuildContainerResponse) ProtoReflect() protoreflect.Message { - mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[1] + mi := &file_codespace_codespace_host_service_v1_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -104,7 +214,7 @@ func (x *RebuildContainerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RebuildContainerResponse.ProtoReflect.Descriptor instead. func (*RebuildContainerResponse) Descriptor() ([]byte, []int) { - return file_codespace_codespace_host_service_v1_proto_rawDescGZIP(), []int{1} + return file_codespace_codespace_host_service_v1_proto_rawDescGZIP(), []int{3} } func (x *RebuildContainerResponse) GetRebuildContainer() bool { @@ -122,29 +232,54 @@ var file_codespace_codespace_host_service_v1_proto_rawDesc = []byte{ 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x27, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x2e, 0x76, 0x31, 0x22, 0x50, 0x0a, 0x17, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, - 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x25, 0x0a, 0x0b, 0x49, 0x6e, 0x63, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x0b, 0x49, 0x6e, 0x63, 0x72, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x61, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x0e, 0x0a, 0x0c, 0x5f, 0x49, 0x6e, 0x63, 0x72, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x22, 0x46, 0x0a, 0x18, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, - 0x64, 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x10, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, - 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x52, 0x65, - 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x32, 0xae, - 0x01, 0x0a, 0x0d, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, 0x73, 0x74, - 0x12, 0x9c, 0x01, 0x0a, 0x15, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, - 0x61, 0x69, 0x6e, 0x65, 0x72, 0x41, 0x73, 0x79, 0x6e, 0x63, 0x12, 0x40, 0x2e, 0x43, 0x6f, 0x64, - 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x43, 0x6f, 0x64, - 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, - 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x41, 0x2e, 0x43, - 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x43, - 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, - 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, - 0x0d, 0x5a, 0x0b, 0x2e, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x2e, 0x76, 0x31, 0x22, 0x70, 0x0a, 0x26, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x79, 0x43, 0x6f, + 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x4f, 0x66, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x41, + 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, + 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x02, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, + 0x76, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x27, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x79, + 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x4f, 0x66, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x50, 0x0a, 0x17, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, + 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, + 0x0a, 0x0b, 0x49, 0x6e, 0x63, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x0b, 0x49, 0x6e, 0x63, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x61, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x0e, 0x0a, 0x0c, 0x5f, 0x49, 0x6e, 0x63, 0x72, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x22, 0x46, 0x0a, 0x18, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, + 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x2a, 0x0a, 0x10, 0x52, 0x65, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, + 0x61, 0x69, 0x6e, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x52, 0x65, 0x62, + 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x32, 0xf5, 0x02, + 0x0a, 0x0d, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, 0x73, 0x74, 0x12, + 0xc4, 0x01, 0x0a, 0x1f, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, + 0x61, 0x63, 0x65, 0x4f, 0x66, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, + 0x69, 0x74, 0x79, 0x12, 0x4f, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, + 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, + 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, + 0x74, 0x69, 0x66, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x4f, 0x66, 0x43, + 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x50, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, + 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, + 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, + 0x6f, 0x74, 0x69, 0x66, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x4f, 0x66, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x9c, 0x01, 0x0a, 0x15, 0x52, 0x65, 0x62, 0x75, 0x69, + 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x41, 0x73, 0x79, 0x6e, 0x63, + 0x12, 0x40, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e, 0x47, 0x72, + 0x70, 0x63, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, 0x73, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x62, 0x75, 0x69, + 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x41, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e, + 0x47, 0x72, 0x70, 0x63, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x48, 0x6f, + 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x62, + 0x75, 0x69, 0x6c, 0x64, 0x43, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x0d, 0x5a, 0x0b, 0x2e, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x73, + 0x70, 0x61, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -159,16 +294,20 @@ func file_codespace_codespace_host_service_v1_proto_rawDescGZIP() []byte { return file_codespace_codespace_host_service_v1_proto_rawDescData } -var file_codespace_codespace_host_service_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_codespace_codespace_host_service_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 4) var file_codespace_codespace_host_service_v1_proto_goTypes = []interface{}{ - (*RebuildContainerRequest)(nil), // 0: Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerRequest - (*RebuildContainerResponse)(nil), // 1: Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerResponse + (*NotifyCodespaceOfClientActivityRequest)(nil), // 0: Codespaces.Grpc.CodespaceHostService.v1.NotifyCodespaceOfClientActivityRequest + (*NotifyCodespaceOfClientActivityResponse)(nil), // 1: Codespaces.Grpc.CodespaceHostService.v1.NotifyCodespaceOfClientActivityResponse + (*RebuildContainerRequest)(nil), // 2: Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerRequest + (*RebuildContainerResponse)(nil), // 3: Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerResponse } var file_codespace_codespace_host_service_v1_proto_depIdxs = []int32{ - 0, // 0: Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost.RebuildContainerAsync:input_type -> Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerRequest - 1, // 1: Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost.RebuildContainerAsync:output_type -> Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerResponse - 1, // [1:2] is the sub-list for method output_type - 0, // [0:1] is the sub-list for method input_type + 0, // 0: Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost.NotifyCodespaceOfClientActivity:input_type -> Codespaces.Grpc.CodespaceHostService.v1.NotifyCodespaceOfClientActivityRequest + 2, // 1: Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost.RebuildContainerAsync:input_type -> Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerRequest + 1, // 2: Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost.NotifyCodespaceOfClientActivity:output_type -> Codespaces.Grpc.CodespaceHostService.v1.NotifyCodespaceOfClientActivityResponse + 3, // 3: Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost.RebuildContainerAsync:output_type -> Codespaces.Grpc.CodespaceHostService.v1.RebuildContainerResponse + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -181,7 +320,7 @@ func file_codespace_codespace_host_service_v1_proto_init() { } if !protoimpl.UnsafeEnabled { file_codespace_codespace_host_service_v1_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RebuildContainerRequest); i { + switch v := v.(*NotifyCodespaceOfClientActivityRequest); i { case 0: return &v.state case 1: @@ -193,6 +332,30 @@ func file_codespace_codespace_host_service_v1_proto_init() { } } file_codespace_codespace_host_service_v1_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*NotifyCodespaceOfClientActivityResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_codespace_codespace_host_service_v1_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RebuildContainerRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_codespace_codespace_host_service_v1_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*RebuildContainerResponse); i { case 0: return &v.state @@ -205,14 +368,14 @@ func file_codespace_codespace_host_service_v1_proto_init() { } } } - file_codespace_codespace_host_service_v1_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_codespace_codespace_host_service_v1_proto_msgTypes[2].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_codespace_codespace_host_service_v1_proto_rawDesc, NumEnums: 0, - NumMessages: 2, + NumMessages: 4, NumExtensions: 0, NumServices: 1, }, diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto index 40078d367..b2cc92949 100644 --- a/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto @@ -5,9 +5,19 @@ option go_package = "./codespace"; package Codespaces.Grpc.CodespaceHostService.v1; service CodespaceHost { + rpc NotifyCodespaceOfClientActivity (NotifyCodespaceOfClientActivityRequest) returns (NotifyCodespaceOfClientActivityResponse); rpc RebuildContainerAsync (RebuildContainerRequest) returns (RebuildContainerResponse); } +message NotifyCodespaceOfClientActivityRequest { + string ClientId = 1; + repeated string ClientActivities = 2; +} +message NotifyCodespaceOfClientActivityResponse { + bool Result = 1; + string Message = 2; +} + message RebuildContainerRequest { optional bool Incremental = 1; } diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go new file mode 100644 index 000000000..246849fe0 --- /dev/null +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go @@ -0,0 +1,168 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package codespace + +import ( + context "context" + sync "sync" +) + +// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer. +// If this is not the case, regenerate this file with moq. +var _ CodespaceHostServer = &CodespaceHostServerMock{} + +// CodespaceHostServerMock is a mock implementation of CodespaceHostServer. +// +// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) { +// +// // make and configure a mocked CodespaceHostServer +// mockedCodespaceHostServer := &CodespaceHostServerMock{ +// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { +// panic("mock out the NotifyCodespaceOfClientActivity method") +// }, +// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) { +// panic("mock out the RebuildContainerAsync method") +// }, +// mustEmbedUnimplementedCodespaceHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method") +// }, +// } +// +// // use mockedCodespaceHostServer in code that requires CodespaceHostServer +// // and then make assertions. +// +// } +type CodespaceHostServerMock struct { + // NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method. + NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) + + // RebuildContainerAsyncFunc mocks the RebuildContainerAsync method. + RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) + + // mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method. + mustEmbedUnimplementedCodespaceHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method. + NotifyCodespaceOfClientActivity []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value. + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + } + // RebuildContainerAsync holds details about calls to the RebuildContainerAsync method. + RebuildContainerAsync []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // RebuildContainerRequest is the rebuildContainerRequest argument value. + RebuildContainerRequest *RebuildContainerRequest + } + // mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method. + mustEmbedUnimplementedCodespaceHostServer []struct { + } + } + lockNotifyCodespaceOfClientActivity sync.RWMutex + lockRebuildContainerAsync sync.RWMutex + lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex +} + +// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc. +func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { + if mock.NotifyCodespaceOfClientActivityFunc == nil { + panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called") + } + callInfo := struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + }{ + ContextMoqParam: contextMoqParam, + NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest, + } + mock.lockNotifyCodespaceOfClientActivity.Lock() + mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo) + mock.lockNotifyCodespaceOfClientActivity.Unlock() + return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest) +} + +// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity. +// Check the length with: +// +// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls()) +func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest +} { + var calls []struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + } + mock.lockNotifyCodespaceOfClientActivity.RLock() + calls = mock.calls.NotifyCodespaceOfClientActivity + mock.lockNotifyCodespaceOfClientActivity.RUnlock() + return calls +} + +// RebuildContainerAsync calls RebuildContainerAsyncFunc. +func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) { + if mock.RebuildContainerAsyncFunc == nil { + panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called") + } + callInfo := struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest + }{ + ContextMoqParam: contextMoqParam, + RebuildContainerRequest: rebuildContainerRequest, + } + mock.lockRebuildContainerAsync.Lock() + mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo) + mock.lockRebuildContainerAsync.Unlock() + return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest) +} + +// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync. +// Check the length with: +// +// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls()) +func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest +} { + var calls []struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest + } + mock.lockRebuildContainerAsync.RLock() + calls = mock.calls.RebuildContainerAsync + mock.lockRebuildContainerAsync.RUnlock() + return calls +} + +// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc. +func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() { + if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil { + panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock() + mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo) + mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock() + mock.mustEmbedUnimplementedCodespaceHostServerFunc() +} + +// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer. +// Check the length with: +// +// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls()) +func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer + mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1_grpc.pb.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1_grpc.pb.go index 0be6bdc58..e876578ad 100644 --- a/internal/codespaces/rpc/codespace/codespace_host_service.v1_grpc.pb.go +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1_grpc.pb.go @@ -22,6 +22,7 @@ const _ = grpc.SupportPackageIsVersion7 // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type CodespaceHostClient interface { + NotifyCodespaceOfClientActivity(ctx context.Context, in *NotifyCodespaceOfClientActivityRequest, opts ...grpc.CallOption) (*NotifyCodespaceOfClientActivityResponse, error) RebuildContainerAsync(ctx context.Context, in *RebuildContainerRequest, opts ...grpc.CallOption) (*RebuildContainerResponse, error) } @@ -33,6 +34,15 @@ func NewCodespaceHostClient(cc grpc.ClientConnInterface) CodespaceHostClient { return &codespaceHostClient{cc} } +func (c *codespaceHostClient) NotifyCodespaceOfClientActivity(ctx context.Context, in *NotifyCodespaceOfClientActivityRequest, opts ...grpc.CallOption) (*NotifyCodespaceOfClientActivityResponse, error) { + out := new(NotifyCodespaceOfClientActivityResponse) + err := c.cc.Invoke(ctx, "/Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost/NotifyCodespaceOfClientActivity", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *codespaceHostClient) RebuildContainerAsync(ctx context.Context, in *RebuildContainerRequest, opts ...grpc.CallOption) (*RebuildContainerResponse, error) { out := new(RebuildContainerResponse) err := c.cc.Invoke(ctx, "/Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost/RebuildContainerAsync", in, out, opts...) @@ -46,6 +56,7 @@ func (c *codespaceHostClient) RebuildContainerAsync(ctx context.Context, in *Reb // All implementations must embed UnimplementedCodespaceHostServer // for forward compatibility type CodespaceHostServer interface { + NotifyCodespaceOfClientActivity(context.Context, *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) RebuildContainerAsync(context.Context, *RebuildContainerRequest) (*RebuildContainerResponse, error) mustEmbedUnimplementedCodespaceHostServer() } @@ -54,6 +65,9 @@ type CodespaceHostServer interface { type UnimplementedCodespaceHostServer struct { } +func (UnimplementedCodespaceHostServer) NotifyCodespaceOfClientActivity(context.Context, *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method NotifyCodespaceOfClientActivity not implemented") +} func (UnimplementedCodespaceHostServer) RebuildContainerAsync(context.Context, *RebuildContainerRequest) (*RebuildContainerResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method RebuildContainerAsync not implemented") } @@ -70,6 +84,24 @@ func RegisterCodespaceHostServer(s grpc.ServiceRegistrar, srv CodespaceHostServe s.RegisterService(&CodespaceHost_ServiceDesc, srv) } +func _CodespaceHost_NotifyCodespaceOfClientActivity_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(NotifyCodespaceOfClientActivityRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CodespaceHostServer).NotifyCodespaceOfClientActivity(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost/NotifyCodespaceOfClientActivity", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CodespaceHostServer).NotifyCodespaceOfClientActivity(ctx, req.(*NotifyCodespaceOfClientActivityRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _CodespaceHost_RebuildContainerAsync_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(RebuildContainerRequest) if err := dec(in); err != nil { @@ -95,6 +127,10 @@ var CodespaceHost_ServiceDesc = grpc.ServiceDesc{ ServiceName: "Codespaces.Grpc.CodespaceHostService.v1.CodespaceHost", HandlerType: (*CodespaceHostServer)(nil), Methods: []grpc.MethodDesc{ + { + MethodName: "NotifyCodespaceOfClientActivity", + Handler: _CodespaceHost_NotifyCodespaceOfClientActivity_Handler, + }, { MethodName: "RebuildContainerAsync", Handler: _CodespaceHost_RebuildContainerAsync_Handler, diff --git a/internal/codespaces/rpc/generate.md b/internal/codespaces/rpc/generate.md index 7ae1dcc1a..d0d6bbc9d 100644 --- a/internal/codespaces/rpc/generate.md +++ b/internal/codespaces/rpc/generate.md @@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers. 1. [Download `protoc`](https://grpc.io/docs/protoc-installation/) 2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/) -3. Run `./generate.sh` from the `internal/codespaces/grpc` directory +3. Install moq: `go install github.com/matryer/moq@latest` +4. Run `./generate.sh` from the `internal/codespaces/rpc` directory ## Add New Protocol Buffers diff --git a/internal/codespaces/rpc/generate.sh b/internal/codespaces/rpc/generate.sh index 159803bbe..4ba2f898a 100755 --- a/internal/codespaces/rpc/generate.sh +++ b/internal/codespaces/rpc/generate.sh @@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then fi function generate { - local contract="$1" + local dir="$1" + local proto="$2" + + local contract="$dir/$proto" protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract" echo "Generated protocol buffers for $contract" + + services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}') + moq -out $contract.mock.go $dir $services + echo "Generated mock protocols for $contract" } -generate jupyter/jupyter_server_host_service.v1.proto -generate codespace/codespace_host_service.v1.proto -generate ssh/ssh_server_host_service.v1.proto +generate jupyter jupyter_server_host_service.v1.proto +generate codespace codespace_host_service.v1.proto +generate ssh ssh_server_host_service.v1.proto echo 'Done!' diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index 67a88bb2f..bb2e25a55 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -29,6 +29,8 @@ const ( const ( codespacesInternalPort = 16634 codespacesInternalSessionName = "CodespacesInternal" + clientName = "gh" + connectedEventName = "connected" ) type StartSSHServerOptions struct { @@ -68,11 +70,11 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv // Finds a free port to listen on and creates a new RPC invoker that connects to that port func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) { - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) + listener, err := listenTCP() if err != nil { - return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) + return nil, err } - localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) + localAddress := listener.Addr().String() invoker := &invoker{ session: session, @@ -128,6 +130,12 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, invoker.codespaceClient = codespace.NewCodespaceHostClient(conn) invoker.sshClient = ssh.NewSshServerHostClient(conn) + // Send initial connection heartbeat (no need to throw if we fail to get a response from the server) + _ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName) + + // Start the activity heatbeats + go invoker.heartbeat(pfctx, 1*time.Minute) + return invoker, nil } @@ -229,3 +237,45 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS return port, response.User, nil } + +func listenTCP() (*net.TCPListener, error) { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("failed to build tcp address: %w", err) + } + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) + } + + return listener, nil +} + +// Periodically check whether there is a reason to keep the connection alive, and if so, notify the codespace to do so +func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + reason := i.session.GetKeepAliveReason() + _ = i.notifyCodespaceOfClientActivity(ctx, reason) + } + } +} + +func (i *invoker) notifyCodespaceOfClientActivity(ctx context.Context, activity string) error { + ctx = i.appendMetadata(ctx) + ctx, cancel := context.WithTimeout(ctx, requestTimeout) + defer cancel() + + _, err := i.codespaceClient.NotifyCodespaceOfClientActivity(ctx, &codespace.NotifyCodespaceOfClientActivityRequest{ClientId: clientName, ClientActivities: []string{activity}}) + if err != nil { + return fmt.Errorf("failed to invoke notify RPC: %w", err) + } + + return nil +} diff --git a/internal/codespaces/rpc/invoker_test.go b/internal/codespaces/rpc/invoker_test.go index bfed27181..ba3e13ac3 100644 --- a/internal/codespaces/rpc/invoker_test.go +++ b/internal/codespaces/rpc/invoker_test.go @@ -3,73 +3,159 @@ package rpc import ( "context" "fmt" - "log" - "os" + "net" + "strconv" "testing" + "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" + "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" + "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test" + "google.golang.org/grpc" ) -func startServer(t *testing.T) { - t.Helper() - if os.Getenv("GITHUB_ACTIONS") == "true" { - t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663") +type mockServer struct { + jupyter.JupyterServerHostServerMock + codespace.CodespaceHostServerMock + ssh.SshServerHostServerMock +} + +func newMockServer() *mockServer { + server := &mockServer{} + + server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) { + return &codespace.NotifyCodespaceOfClientActivityResponse{ + Message: "", + Result: true, + }, nil + } + + return server +} + +// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks. +// It does not return until the Context is cancelled and the server fully shuts down. +func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error { + s := grpc.NewServer() + jupyter.RegisterJupyterServerHostServer(s, server) + codespace.RegisterCodespaceHostServer(s, server) + ssh.RegisterSshServerHostServer(s, server) + + ch := make(chan error, 1) + go func() { ch <- s.Serve(listener) }() + + select { + case <-ctx.Done(): + s.Stop() + <-ch + return nil + case err := <-ch: + return err + } +} + +// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function. +// The Invoker does not need to be closed directly, that will be handled by the shutdown function. +func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) { + listener, err := net.Listen("tcp", "127.0.0.1:16634") + if err != nil { + return nil, nil, fmt.Errorf("failed to listen: %w", err) } ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan error) + go func() { ch <- runTestGrpcServer(ctx, listener, server) }() - // Start the gRPC server in the background - go func() { - err := rpctest.StartServer(ctx) - if err != nil && err != context.Canceled { - log.Println(fmt.Errorf("error starting test server: %v", err)) - } - }() - - // Stop the gRPC server when the test is done - t.Cleanup(func() { + close := func() { cancel() - }) -} - -func createTestInvoker(t *testing.T) Invoker { - t.Helper() + <-ch + listener.Close() + } invoker, err := CreateInvoker(context.Background(), &rpctest.Session{}) if err != nil { - t.Fatalf("error connecting to internal server: %v", err) + close() + return nil, nil, fmt.Errorf("error connecting to internal server: %w", err) } - t.Cleanup(func() { + return invoker, func() { invoker.Close() - }) + close() + }, nil +} - return invoker +// Test that the RPC invoker notifies the codespace of client activity on connection +func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) { + calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls() + if len(calls) == 0 { + t.Fatalf("no client activity calls") + } + + for _, call := range calls { + activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities() + if activities[0] == connectedEventName { + return + } + } + + t.Fatalf("no activity named %s", connectedEventName) } // Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully func TestStartJupyterServerSuccess(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) + resp := jupyter.GetRunningServerResponse{ + Port: strconv.Itoa(1234), + ServerUrl: "http://localhost:1234?token=1234", + Message: "", + Result: true, + } + + server := newMockServer() + server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + port, url, err := invoker.StartJupyterServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != rpctest.JupyterPort { - t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port) + if strconv.Itoa(port) != resp.Port { + t.Fatalf("expected %s, got %d", resp.Port, port) } - if url != rpctest.JupyterServerUrl { - t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url) + if url != resp.ServerUrl { + t.Fatalf("expected %s, got %s", resp.ServerUrl, url) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker returns an error when the JupyterLab server fails to start func TestStartJupyterServerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.JupyterMessage = "error message" - rpctest.JupyterResult = false - errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage) + resp := jupyter.GetRunningServerResponse{ + Port: strconv.Itoa(1234), + ServerUrl: "http://localhost:1234?token=1234", + Message: "error message", + Result: false, + } + + server := newMockServer() + server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message) port, url, err := invoker.StartJupyterServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) @@ -80,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) { if url != "" { t.Fatalf("expected %s, got %s", "", url) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild func TestRebuildContainerIncremental(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - err := invoker.RebuildContainer(context.Background(), false) + resp := codespace.RebuildContainerResponse{ + RebuildContainer: true, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + err = invoker.RebuildContainer(context.Background(), false) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker doesn't throw an error when requesting a full rebuild func TestRebuildContainerFull(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - err := invoker.RebuildContainer(context.Background(), true) + resp := codespace.RebuildContainerResponse{ + RebuildContainer: true, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + err = invoker.RebuildContainer(context.Background(), true) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker throws an error when the rebuild fails func TestRebuildContainerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.RebuildContainer = false + resp := codespace.RebuildContainerResponse{ + RebuildContainer: false, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + errorMessage := "couldn't rebuild codespace" - err := invoker.RebuildContainer(context.Background(), true) + err = invoker.RebuildContainer(context.Background(), true) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) } @@ -116,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) { // Test that the RPC invoker returns the correct port and user when the SSH server starts successfully func TestStartSSHServerSuccess(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) + resp := ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(1234), + User: "test", + Message: "", + Result: true, + } + + server := newMockServer() + server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + port, user, err := invoker.StartSSHServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != rpctest.SshServerPort { - t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port) + if strconv.Itoa(port) != resp.ServerPort { + t.Fatalf("expected %s, got %d", resp.ServerPort, port) } - if user != rpctest.SshUser { - t.Fatalf("expected %s, got %s", rpctest.SshUser, user) + if user != resp.User { + t.Fatalf("expected %s, got %s", resp.User, user) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker returns an error when the SSH server fails to start func TestStartSSHServerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.SshMessage = "error message" - rpctest.SshResult = false - errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage) + resp := ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(1234), + User: "test", + Message: "error message", + Result: false, + } + + server := newMockServer() + server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message) port, user, err := invoker.StartSSHServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) diff --git a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go index 8e11c6a32..b8f400d3c 100644 --- a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go +++ b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: jupyter/jupyter_server_host_service.v1.proto diff --git a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go new file mode 100644 index 000000000..12ea0bb5b --- /dev/null +++ b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go @@ -0,0 +1,118 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package jupyter + +import ( + context "context" + sync "sync" +) + +// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer. +// If this is not the case, regenerate this file with moq. +var _ JupyterServerHostServer = &JupyterServerHostServerMock{} + +// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer. +// +// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) { +// +// // make and configure a mocked JupyterServerHostServer +// mockedJupyterServerHostServer := &JupyterServerHostServerMock{ +// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) { +// panic("mock out the GetRunningServer method") +// }, +// mustEmbedUnimplementedJupyterServerHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method") +// }, +// } +// +// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer +// // and then make assertions. +// +// } +type JupyterServerHostServerMock struct { + // GetRunningServerFunc mocks the GetRunningServer method. + GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) + + // mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method. + mustEmbedUnimplementedJupyterServerHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // GetRunningServer holds details about calls to the GetRunningServer method. + GetRunningServer []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // GetRunningServerRequest is the getRunningServerRequest argument value. + GetRunningServerRequest *GetRunningServerRequest + } + // mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method. + mustEmbedUnimplementedJupyterServerHostServer []struct { + } + } + lockGetRunningServer sync.RWMutex + lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex +} + +// GetRunningServer calls GetRunningServerFunc. +func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) { + if mock.GetRunningServerFunc == nil { + panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called") + } + callInfo := struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest + }{ + ContextMoqParam: contextMoqParam, + GetRunningServerRequest: getRunningServerRequest, + } + mock.lockGetRunningServer.Lock() + mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo) + mock.lockGetRunningServer.Unlock() + return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest) +} + +// GetRunningServerCalls gets all the calls that were made to GetRunningServer. +// Check the length with: +// +// len(mockedJupyterServerHostServer.GetRunningServerCalls()) +func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest +} { + var calls []struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest + } + mock.lockGetRunningServer.RLock() + calls = mock.calls.GetRunningServer + mock.lockGetRunningServer.RUnlock() + return calls +} + +// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc. +func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() { + if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil { + panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock() + mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo) + mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock() + mock.mustEmbedUnimplementedJupyterServerHostServerFunc() +} + +// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer. +// Check the length with: +// +// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls()) +func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer + mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go index c495eb781..3dd22f583 100644 --- a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: ssh/ssh_server_host_service.v1.proto diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go new file mode 100644 index 000000000..d11e99461 --- /dev/null +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go @@ -0,0 +1,118 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ssh + +import ( + context "context" + sync "sync" +) + +// Ensure, that SshServerHostServerMock does implement SshServerHostServer. +// If this is not the case, regenerate this file with moq. +var _ SshServerHostServer = &SshServerHostServerMock{} + +// SshServerHostServerMock is a mock implementation of SshServerHostServer. +// +// func TestSomethingThatUsesSshServerHostServer(t *testing.T) { +// +// // make and configure a mocked SshServerHostServer +// mockedSshServerHostServer := &SshServerHostServerMock{ +// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { +// panic("mock out the StartRemoteServerAsync method") +// }, +// mustEmbedUnimplementedSshServerHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedSshServerHostServer method") +// }, +// } +// +// // use mockedSshServerHostServer in code that requires SshServerHostServer +// // and then make assertions. +// +// } +type SshServerHostServerMock struct { + // StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method. + StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) + + // mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method. + mustEmbedUnimplementedSshServerHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method. + StartRemoteServerAsync []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // StartRemoteServerRequest is the startRemoteServerRequest argument value. + StartRemoteServerRequest *StartRemoteServerRequest + } + // mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method. + mustEmbedUnimplementedSshServerHostServer []struct { + } + } + lockStartRemoteServerAsync sync.RWMutex + lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex +} + +// StartRemoteServerAsync calls StartRemoteServerAsyncFunc. +func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { + if mock.StartRemoteServerAsyncFunc == nil { + panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called") + } + callInfo := struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest + }{ + ContextMoqParam: contextMoqParam, + StartRemoteServerRequest: startRemoteServerRequest, + } + mock.lockStartRemoteServerAsync.Lock() + mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo) + mock.lockStartRemoteServerAsync.Unlock() + return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest) +} + +// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync. +// Check the length with: +// +// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls()) +func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest +} { + var calls []struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest + } + mock.lockStartRemoteServerAsync.RLock() + calls = mock.calls.StartRemoteServerAsync + mock.lockStartRemoteServerAsync.RUnlock() + return calls +} + +// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc. +func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() { + if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil { + panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedSshServerHostServer.Lock() + mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo) + mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock() + mock.mustEmbedUnimplementedSshServerHostServerFunc() +} + +// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer. +// Check the length with: +// +// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls()) +func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedSshServerHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedSshServerHostServer + mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/test/server.go b/internal/codespaces/rpc/test/server.go deleted file mode 100644 index d2dc9f590..000000000 --- a/internal/codespaces/rpc/test/server.go +++ /dev/null @@ -1,97 +0,0 @@ -package test - -import ( - "context" - "fmt" - "net" - "strconv" - - "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" - "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" - "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" - "google.golang.org/grpc" -) - -const ( - ServerPort = 50051 -) - -// Mock responses for the `GetRunningServer` RPC method -var ( - JupyterPort = 1234 - JupyterServerUrl = "http://localhost:1234?token=1234" - JupyterMessage = "" - JupyterResult = true -) - -// Mock responses for the `RebuildContainerAsync` RPC method -var ( - RebuildContainer = true -) - -// Mock responses for the `StartRemoteServerAsync` RPC method -var ( - SshServerPort = 1234 - SshUser = "test" - SshMessage = "" - SshResult = true -) - -type server struct { - jupyter.UnimplementedJupyterServerHostServer - codespace.CodespaceHostServer - ssh.SshServerHostServer -} - -func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { - return &jupyter.GetRunningServerResponse{ - Port: strconv.Itoa(JupyterPort), - ServerUrl: JupyterServerUrl, - Message: JupyterMessage, - Result: JupyterResult, - }, nil -} - -func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { - return &codespace.RebuildContainerResponse{ - RebuildContainer: RebuildContainer, - }, nil -} - -func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { - return &ssh.StartRemoteServerResponse{ - ServerPort: strconv.Itoa(SshServerPort), - User: SshUser, - Message: SshMessage, - Result: SshResult, - }, nil -} - -// Starts the mock gRPC server listening on port 50051 -func StartServer(ctx context.Context) error { - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) - if err != nil { - return fmt.Errorf("failed to listen: %v", err) - } - defer listener.Close() - - s := grpc.NewServer() - jupyter.RegisterJupyterServerHostServer(s, &server{}) - codespace.RegisterCodespaceHostServer(s, &server{}) - ssh.RegisterSshServerHostServer(s, &server{}) - - ch := make(chan error, 1) - go func() { - if err := s.Serve(listener); err != nil { - ch <- fmt.Errorf("failed to serve: %v", err) - } - }() - - select { - case <-ctx.Done(): - s.Stop() - return ctx.Err() - case err := <-ch: - return err - } -} diff --git a/internal/codespaces/rpc/test/session.go b/internal/codespaces/rpc/test/session.go index 89d66a912..531d4c33f 100644 --- a/internal/codespaces/rpc/test/session.go +++ b/internal/codespaces/rpc/test/session.go @@ -24,8 +24,12 @@ func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) { func (s *Session) KeepAlive(reason string) { } +func (s *Session) GetKeepAliveReason() string { + return "" +} + func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) { - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { return liveshare.ChannelID{}, err } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 58be127be..9874d1a62 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "log" - "net" "time" "github.com/cli/cli/v2/internal/codespaces/api" @@ -53,11 +52,10 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl }() // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port + listen, localPort, err := ListenTCP(0) if err != nil { return err } - localPort := listen.Addr().(*net.TCPAddr).Port progress.StartProgressIndicatorWithLabel("Fetching SSH Details") invoker, err := rpc.CreateInvoker(ctx, session) diff --git a/internal/run/stub.go b/internal/run/stub.go index 13a506c11..49bf62d29 100644 --- a/internal/run/stub.go +++ b/internal/run/stub.go @@ -106,18 +106,36 @@ type commandStub struct { callbacks []CommandCallback } +type errWithExitCode struct { + message string + exitCode int +} + +func (e errWithExitCode) Error() string { + return e.message +} + +func (e errWithExitCode) ExitCode() int { + return e.exitCode +} + // Run satisfies Runnable func (s *commandStub) Run() error { if s.exitStatus != 0 { - return fmt.Errorf("%s exited with status %d", s.pattern, s.exitStatus) + // It's nontrivial to construct a fake `exec.ExitError` instance, so we return an error type + // that has the `ExitCode() int` method. + return errWithExitCode{ + message: fmt.Sprintf("%s exited with status %d", s.pattern, s.exitStatus), + exitCode: s.exitStatus, + } } return nil } // Output satisfies Runnable func (s *commandStub) Output() ([]byte, error) { - if s.exitStatus != 0 { - return []byte(nil), fmt.Errorf("%s exited with status %d", s.pattern, s.exitStatus) + if err := s.Run(); err != nil { + return []byte(nil), err } return []byte(s.stdout), nil } diff --git a/internal/tableprinter/table_printer.go b/internal/tableprinter/table_printer.go index 2e8d398eb..059203a6e 100644 --- a/internal/tableprinter/table_printer.go +++ b/internal/tableprinter/table_printer.go @@ -24,11 +24,12 @@ func (t *TablePrinter) HeaderRow(columns ...string) { t.EndRow() } -func (tp *TablePrinter) AddTimeField(t time.Time, c func(string) string) { +// In tty mode display the fuzzy time difference between now and t. +// In nontty mode just display t with the time.RFC3339 format. +func (tp *TablePrinter) AddTimeField(now, t time.Time, c func(string) string) { tf := t.Format(time.RFC3339) if tp.isTTY { - // TODO: use a static time.Now - tf = text.FuzzyAgo(time.Now(), t) + tf = text.FuzzyAgo(now, t) } tp.AddField(tf, tableprinter.WithColor(c)) } diff --git a/internal/update/update.go b/internal/update/update.go index e9ada22f6..6d69eeada 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -1,7 +1,11 @@ package update import ( + "context" + "encoding/json" "fmt" + "io" + "net/http" "os" "path/filepath" "regexp" @@ -9,8 +13,6 @@ import ( "strings" "time" - "github.com/cli/cli/v2/api" - "github.com/cli/cli/v2/internal/ghinstance" "github.com/hashicorp/go-version" "gopkg.in/yaml.v3" ) @@ -30,13 +32,13 @@ type StateEntry struct { } // CheckForUpdate checks whether this software has had a newer release on GitHub -func CheckForUpdate(client *api.Client, stateFilePath, repo, currentVersion string) (*ReleaseInfo, error) { +func CheckForUpdate(ctx context.Context, client *http.Client, stateFilePath, repo, currentVersion string) (*ReleaseInfo, error) { stateEntry, _ := getStateEntry(stateFilePath) if stateEntry != nil && time.Since(stateEntry.CheckedForUpdateAt).Hours() < 24 { return nil, nil } - releaseInfo, err := getLatestReleaseInfo(client, repo) + releaseInfo, err := getLatestReleaseInfo(ctx, client, repo) if err != nil { return nil, err } @@ -53,13 +55,27 @@ func CheckForUpdate(client *api.Client, stateFilePath, repo, currentVersion stri return nil, nil } -func getLatestReleaseInfo(client *api.Client, repo string) (*ReleaseInfo, error) { - var latestRelease ReleaseInfo - err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease) +func getLatestReleaseInfo(ctx context.Context, client *http.Client, repo string) (*ReleaseInfo, error) { + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo), nil) if err != nil { return nil, err } - + res, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _, _ = io.Copy(io.Discard, res.Body) + res.Body.Close() + }() + if res.StatusCode != 200 { + return nil, fmt.Errorf("unexpected HTTP %d", res.StatusCode) + } + dec := json.NewDecoder(res.Body) + var latestRelease ReleaseInfo + if err := dec.Decode(&latestRelease); err != nil { + return nil, err + } return &latestRelease, nil } diff --git a/internal/update/update_test.go b/internal/update/update_test.go index 96dce4f2a..bb514adfc 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -1,13 +1,13 @@ package update import ( + "context" "fmt" "log" "net/http" "os" "testing" - "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/pkg/httpmock" ) @@ -75,7 +75,6 @@ func TestCheckForUpdate(t *testing.T) { reg := &httpmock.Registry{} httpClient := &http.Client{} httpmock.ReplaceTripper(httpClient, reg) - client := api.NewClientFromHTTP(httpClient) reg.Register( httpmock.REST("GET", "repos/OWNER/REPO/releases/latest"), @@ -85,7 +84,7 @@ func TestCheckForUpdate(t *testing.T) { }`, s.LatestVersion, s.LatestURL)), ) - rel, err := CheckForUpdate(client, tempFilePath(), "OWNER/REPO", s.CurrentVersion) + rel, err := CheckForUpdate(context.TODO(), httpClient, tempFilePath(), "OWNER/REPO", s.CurrentVersion) if err != nil { t.Fatal(err) } diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 00e610dcd..643f14094 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -7,7 +7,9 @@ import ( "io" "net/http" "os" + "path/filepath" "regexp" + "runtime" "sort" "strings" "time" @@ -183,6 +185,10 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*ApiOptions) error) *cobra.Command opts.RequestPath = args[0] opts.RequestMethodPassed = c.Flags().Changed("method") + if runtime.GOOS == "windows" && filepath.IsAbs(opts.RequestPath) { + return fmt.Errorf(`invalid API endpoint: "%s". Your shell might be rewriting URL paths as filesystem paths. To avoid this, omit the leading slash from the endpoint argument`, opts.RequestPath) + } + if c.Flags().Changed("hostname") { if err := ghinstance.HostnameValidator(opts.Hostname); err != nil { return cmdutil.FlagErrorf("error parsing `--hostname`: %w", err) diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index b9cd2777f..27a47dd92 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "path/filepath" + "runtime" "strings" "testing" "time" @@ -355,6 +356,20 @@ func Test_NewCmdApi(t *testing.T) { } } +func Test_NewCmdApi_WindowsAbsPath(t *testing.T) { + if runtime.GOOS != "windows" { + t.SkipNow() + } + + cmd := NewCmdApi(&cmdutil.Factory{}, func(opts *ApiOptions) error { + return nil + }) + + cmd.SetArgs([]string{`C:\users\repos`}) + _, err := cmd.ExecuteC() + assert.EqualError(t, err, `invalid API endpoint: "C:\users\repos". Your shell might be rewriting URL paths as filesystem paths. To avoid this, omit the leading slash from the endpoint argument`) +} + func Test_apiRun(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/auth/shared/oauth_scopes.go b/pkg/cmd/auth/shared/oauth_scopes.go index 507425167..8d9996019 100644 --- a/pkg/cmd/auth/shared/oauth_scopes.go +++ b/pkg/cmd/auth/shared/oauth_scopes.go @@ -31,6 +31,7 @@ type httpClient interface { Do(*http.Request) (*http.Response, error) } +// GetScopes performs a GitHub API request and returns the value of the X-Oauth-Scopes header. func GetScopes(httpClient httpClient, hostname, authToken string) (string, error) { apiEndpoint := ghinstance.RESTPrefix(hostname) @@ -60,12 +61,20 @@ func GetScopes(httpClient httpClient, hostname, authToken string) (string, error return res.Header.Get("X-Oauth-Scopes"), nil } +// HasMinimumScopes performs a GitHub API request and returns an error if the token used in the request +// lacks the minimum required scopes for performing API operations with gh. func HasMinimumScopes(httpClient httpClient, hostname, authToken string) error { scopesHeader, err := GetScopes(httpClient, hostname, authToken) if err != nil { return err } + return HeaderHasMinimumScopes(scopesHeader) +} + +// HeaderHasMinimumScopes parses the comma separated scopesHeader string and returns an error +// if it lacks the minimum required scopes for performing API operations with gh. +func HeaderHasMinimumScopes(scopesHeader string) error { if scopesHeader == "" { // if the token reports no scopes, assume that it's an integration token and give up on // detecting its capabilities diff --git a/pkg/cmd/auth/shared/oauth_scopes_test.go b/pkg/cmd/auth/shared/oauth_scopes_test.go index 8f4382d9f..b1ea4c601 100644 --- a/pkg/cmd/auth/shared/oauth_scopes_test.go +++ b/pkg/cmd/auth/shared/oauth_scopes_test.go @@ -11,6 +11,53 @@ import ( ) func Test_HasMinimumScopes(t *testing.T) { + tests := []struct { + name string + header string + wantErr string + }{ + { + name: "write:org satisfies read:org", + header: "repo, write:org", + wantErr: "", + }, + { + name: "insufficient scope", + header: "repo", + wantErr: "missing required scope 'read:org'", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakehttp := &httpmock.Registry{} + defer fakehttp.Verify(t) + + var gotAuthorization string + fakehttp.Register(httpmock.REST("GET", ""), func(req *http.Request) (*http.Response, error) { + gotAuthorization = req.Header.Get("authorization") + return &http.Response{ + Request: req, + StatusCode: 200, + Body: io.NopCloser(&bytes.Buffer{}), + Header: map[string][]string{ + "X-Oauth-Scopes": {tt.header}, + }, + }, nil + }) + + client := http.Client{Transport: fakehttp} + err := HasMinimumScopes(&client, "github.com", "ATOKEN") + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, gotAuthorization, "token ATOKEN") + }) + } +} + +func Test_HeaderHasMinimumScopes(t *testing.T) { tests := []struct { name string header string @@ -49,31 +96,13 @@ func Test_HasMinimumScopes(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakehttp := &httpmock.Registry{} - defer fakehttp.Verify(t) - var gotAuthorization string - fakehttp.Register(httpmock.REST("GET", ""), func(req *http.Request) (*http.Response, error) { - gotAuthorization = req.Header.Get("authorization") - return &http.Response{ - Request: req, - StatusCode: 200, - Body: io.NopCloser(&bytes.Buffer{}), - Header: map[string][]string{ - "X-Oauth-Scopes": {tt.header}, - }, - }, nil - }) - - client := http.Client{Transport: fakehttp} - err := HasMinimumScopes(&client, "github.com", "ATOKEN") + err := HeaderHasMinimumScopes(tt.header) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) } else { assert.NoError(t, err) } - assert.Equal(t, gotAuthorization, "token ATOKEN") }) } - } diff --git a/pkg/cmd/auth/status/status.go b/pkg/cmd/auth/status/status.go index 750c74e29..c8c4f0ee6 100644 --- a/pkg/cmd/auth/status/status.go +++ b/pkg/cmd/auth/status/status.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "net/http" + "path/filepath" + "strings" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" @@ -90,6 +92,11 @@ func statusRun(opts *StatusOptions) error { isHostnameFound = true token, tokenSource := cfg.AuthToken(hostname) + if tokenSource == "oauth_token" { + // The go-gh function TokenForHost returns this value as source for tokens read from the + // config file, but we want the file path instead. This attempts to reconstruct it. + tokenSource = filepath.Join(config.ConfigDir(), "hosts.yml") + } _, tokenIsWriteable := shared.AuthTokenWriteable(cfg, hostname) statusInfo[hostname] = []string{} @@ -97,24 +104,29 @@ func statusRun(opts *StatusOptions) error { statusInfo[hostname] = append(statusInfo[hostname], fmt.Sprintf(x, ys...)) } - if err := shared.HasMinimumScopes(httpClient, hostname, token); err != nil { + scopesHeader, err := shared.GetScopes(httpClient, hostname, token) + if err != nil { + addMsg("%s %s: authentication failed", cs.Red("X"), hostname) + addMsg("- The %s token in %s is no longer valid.", cs.Bold(hostname), tokenSource) + if tokenIsWriteable { + addMsg("- To re-authenticate, run: %s %s", + cs.Bold("gh auth login -h"), cs.Bold(hostname)) + addMsg("- To forget about this host, run: %s %s", + cs.Bold("gh auth logout -h"), cs.Bold(hostname)) + } + failed = true + continue + } + + if err := shared.HeaderHasMinimumScopes(scopesHeader); err != nil { var missingScopes *shared.MissingScopesError if errors.As(err, &missingScopes) { addMsg("%s %s: the token in %s is %s", cs.Red("X"), hostname, tokenSource, err) if tokenIsWriteable { - addMsg("- To request missing scopes, run: %s %s\n", + addMsg("- To request missing scopes, run: %s %s", cs.Bold("gh auth refresh -h"), cs.Bold(hostname)) } - } else { - addMsg("%s %s: authentication failed", cs.Red("X"), hostname) - addMsg("- The %s token in %s is no longer valid.", cs.Bold(hostname), tokenSource) - if tokenIsWriteable { - addMsg("- To re-authenticate, run: %s %s", - cs.Bold("gh auth login -h"), cs.Bold(hostname)) - addMsg("- To forget about this host, run: %s %s", - cs.Bold("gh auth logout -h"), cs.Bold(hostname)) - } } failed = true } else { @@ -122,23 +134,23 @@ func statusRun(opts *StatusOptions) error { username, err := api.CurrentLoginName(apiClient, hostname) if err != nil { addMsg("%s %s: api call failed: %s", cs.Red("X"), hostname, err) + failed = true } + addMsg("%s Logged in to %s as %s (%s)", cs.SuccessIcon(), hostname, cs.Bold(username), tokenSource) proto, _ := cfg.GetOrDefault(hostname, "git_protocol") if proto != "" { addMsg("%s Git operations for %s configured to use %s protocol.", cs.SuccessIcon(), hostname, cs.Bold(proto)) } - tokenDisplay := "*******************" - if opts.ShowToken { - tokenDisplay = token - } - addMsg("%s Token: %s", cs.SuccessIcon(), tokenDisplay) - } - addMsg("") + addMsg("%s Token: %s", cs.SuccessIcon(), displayToken(token, opts.ShowToken)) - // NB we could take this opportunity to add or fix the "user" key in the hosts config. I chose - // not to since I wanted this command to be read-only. + if scopesHeader != "" { + addMsg("%s Token scopes: %s", cs.SuccessIcon(), scopesHeader) + } else if expectScopes(token) { + addMsg("%s Token scopes: none", cs.Red("X")) + } + } } if !isHostnameFound { @@ -147,11 +159,16 @@ func statusRun(opts *StatusOptions) error { return cmdutil.SilentError } + prevEntry := false for _, hostname := range hostnames { lines, ok := statusInfo[hostname] if !ok { continue } + if prevEntry { + fmt.Fprint(stderr, "\n") + } + prevEntry = true fmt.Fprintf(stderr, "%s\n", cs.Bold(hostname)) for _, line := range lines { fmt.Fprintf(stderr, " %s\n", line) @@ -164,3 +181,20 @@ func statusRun(opts *StatusOptions) error { return nil } + +func displayToken(token string, printRaw bool) string { + if printRaw { + return token + } + + if idx := strings.LastIndexByte(token, '_'); idx > -1 { + prefix := token[0 : idx+1] + return prefix + strings.Repeat("*", len(token)-len(prefix)) + } + + return strings.Repeat("*", len(token)) +} + +func expectScopes(token string) bool { + return strings.HasPrefix(token, "ghp_") || strings.HasPrefix(token, "gho_") +} diff --git a/pkg/cmd/auth/status/status_test.go b/pkg/cmd/auth/status/status_test.go index 33aded088..e169d1b3f 100644 --- a/pkg/cmd/auth/status/status_test.go +++ b/pkg/cmd/auth/status/status_test.go @@ -3,9 +3,11 @@ package status import ( "bytes" "net/http" - "regexp" + "path/filepath" + "strings" "testing" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" @@ -74,12 +76,12 @@ func Test_statusRun(t *testing.T) { readConfigs := config.StubWriteConfig(t) tests := []struct { - name string - opts *StatusOptions - httpStubs func(*httpmock.Registry) - cfgStubs func(*config.ConfigMock) - wantErr string - wantErrOut *regexp.Regexp + name string + opts *StatusOptions + httpStubs func(*httpmock.Registry) + cfgStubs func(*config.ConfigMock) + wantErr string + wantOut string }{ { name: "hostname set", @@ -91,12 +93,20 @@ func Test_statusRun(t *testing.T) { c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { + // mocks for HeaderHasMinimumScopes api requests to a non-github.com host reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) + // mock for CurrentLoginName reg.Register( httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, - wantErrOut: regexp.MustCompile(`Logged in to joel.miller as.*tess`), + wantOut: heredoc.Doc(` + joel.miller + ✓ Logged in to joel.miller as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for joel.miller configured to use https protocol. + ✓ Token: ****** + ✓ Token scopes: repo,read:org + `), }, { name: "missing scope", @@ -106,14 +116,27 @@ func Test_statusRun(t *testing.T) { c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { + // mocks for HeaderHasMinimumScopes api requests to a non-github.com host reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo")) + // mocks for HeaderHasMinimumScopes api requests to github.com host reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) + // mock for CurrentLoginName reg.Register( httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, - wantErrOut: regexp.MustCompile(`joel.miller: missing required.*Logged in to github.com as.*tess`), - wantErr: "SilentError", + wantErr: "SilentError", + wantOut: heredoc.Doc(` + joel.miller + X joel.miller: the token in GH_CONFIG_DIR/hosts.yml is missing required scope 'read:org' + - To request missing scopes, run: gh auth refresh -h joel.miller + + github.com + ✓ Logged in to github.com as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for github.com configured to use https protocol. + ✓ Token: ****** + ✓ Token scopes: repo,read:org + `), }, { name: "bad token", @@ -123,25 +146,47 @@ func Test_statusRun(t *testing.T) { c.Set("github.com", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { + // mock for HeaderHasMinimumScopes api requests to a non-github.com host reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.StatusStringResponse(400, "no bueno")) + // mock for HeaderHasMinimumScopes api requests to github.com reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) + // mock for CurrentLoginName reg.Register( httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, - wantErrOut: regexp.MustCompile(`joel.miller: authentication failed.*Logged in to github.com as.*tess`), - wantErr: "SilentError", + wantErr: "SilentError", + wantOut: heredoc.Doc(` + joel.miller + X joel.miller: authentication failed + - The joel.miller token in GH_CONFIG_DIR/hosts.yml is no longer valid. + - To re-authenticate, run: gh auth login -h joel.miller + - To forget about this host, run: gh auth logout -h joel.miller + + github.com + ✓ Logged in to github.com as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for github.com configured to use https protocol. + ✓ Token: ****** + ✓ Token scopes: repo,read:org + `), }, { name: "all good", opts: &StatusOptions{}, cfgStubs: func(c *config.ConfigMock) { - c.Set("github.com", "oauth_token", "abc123") - c.Set("joel.miller", "oauth_token", "abc123") + c.Set("github.com", "oauth_token", "gho_abc123") + c.Set("joel.miller", "oauth_token", "gho_abc123") }, httpStubs: func(reg *httpmock.Registry) { - reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) - reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) + // mocks for HeaderHasMinimumScopes api requests to github.com + reg.Register( + httpmock.REST("GET", ""), + httpmock.WithHeader(httpmock.ScopesResponder("repo,read:org"), "X-Oauth-Scopes", "repo, read:org")) + // mocks for HeaderHasMinimumScopes api requests to a non-github.com host + reg.Register( + httpmock.REST("GET", "api/v3/"), + httpmock.WithHeader(httpmock.ScopesResponder("repo,read:org"), "X-Oauth-Scopes", "")) + // mock for CurrentLoginName, one for each host reg.Register( httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) @@ -149,26 +194,65 @@ func Test_statusRun(t *testing.T) { httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, - wantErrOut: regexp.MustCompile(`(?s)Logged in to github.com as.*tess.*Logged in to joel.miller as.*tess`), + wantOut: heredoc.Doc(` + github.com + ✓ Logged in to github.com as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for github.com configured to use https protocol. + ✓ Token: gho_****** + ✓ Token scopes: repo, read:org + + joel.miller + ✓ Logged in to joel.miller as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for joel.miller configured to use https protocol. + ✓ Token: gho_****** + X Token scopes: none + `), }, { - name: "hide token", + name: "server-to-server token", opts: &StatusOptions{}, cfgStubs: func(c *config.ConfigMock) { - c.Set("joel.miller", "oauth_token", "abc123") - c.Set("github.com", "oauth_token", "xyz456") + c.Set("github.com", "oauth_token", "ghs_xxx") }, httpStubs: func(reg *httpmock.Registry) { - reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) - reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) + // mocks for HeaderHasMinimumScopes api requests to github.com reg.Register( - httpmock.GraphQL(`query UserCurrent\b`), - httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) + httpmock.REST("GET", ""), + httpmock.ScopesResponder("")) + // mock for CurrentLoginName reg.Register( httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, - wantErrOut: regexp.MustCompile(`(?s)Token: \*{19}.*Token: \*{19}`), + wantOut: heredoc.Doc(` + github.com + ✓ Logged in to github.com as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for github.com configured to use https protocol. + ✓ Token: ghs_*** + `), + }, + { + name: "PAT V2 token", + opts: &StatusOptions{}, + cfgStubs: func(c *config.ConfigMock) { + c.Set("github.com", "oauth_token", "github_pat_xxx") + }, + httpStubs: func(reg *httpmock.Registry) { + // mocks for HeaderHasMinimumScopes api requests to github.com + reg.Register( + httpmock.REST("GET", ""), + httpmock.ScopesResponder("")) + // mock for CurrentLoginName + reg.Register( + httpmock.GraphQL(`query UserCurrent\b`), + httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) + }, + wantOut: heredoc.Doc(` + github.com + ✓ Logged in to github.com as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for github.com configured to use https protocol. + ✓ Token: github_pat_*** + `), }, { name: "show token", @@ -180,8 +264,11 @@ func Test_statusRun(t *testing.T) { c.Set("joel.miller", "oauth_token", "abc123") }, httpStubs: func(reg *httpmock.Registry) { + // mocks for HeaderHasMinimumScopes on a non-github.com host reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) + // mocks for HeaderHasMinimumScopes on github.com reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) + // mock for CurrentLoginName, one for each host reg.Register( httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) @@ -189,7 +276,19 @@ func Test_statusRun(t *testing.T) { httpmock.GraphQL(`query UserCurrent\b`), httpmock.StringResponse(`{"data":{"viewer":{"login":"tess"}}}`)) }, - wantErrOut: regexp.MustCompile(`(?s)Token: xyz456.*Token: abc123`), + wantOut: heredoc.Doc(` + github.com + ✓ Logged in to github.com as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for github.com configured to use https protocol. + ✓ Token: xyz456 + ✓ Token scopes: repo,read:org + + joel.miller + ✓ Logged in to joel.miller as tess (GH_CONFIG_DIR/hosts.yml) + ✓ Git operations for joel.miller configured to use https protocol. + ✓ Token: abc123 + ✓ Token scopes: repo,read:org + `), }, { name: "missing hostname", @@ -199,9 +298,9 @@ func Test_statusRun(t *testing.T) { cfgStubs: func(c *config.ConfigMock) { c.Set("github.com", "oauth_token", "abc123") }, - httpStubs: func(reg *httpmock.Registry) {}, - wantErrOut: regexp.MustCompile(`(?s)Hostname "github.example.com" not found among authenticated GitHub hosts`), - wantErr: "SilentError", + httpStubs: func(reg *httpmock.Registry) {}, + wantErr: "SilentError", + wantOut: "Hostname \"github.example.com\" not found among authenticated GitHub hosts\n", }, } @@ -227,6 +326,7 @@ func Test_statusRun(t *testing.T) { } reg := &httpmock.Registry{} + defer reg.Verify(t) tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } @@ -237,16 +337,12 @@ func Test_statusRun(t *testing.T) { err := statusRun(tt.opts) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) - return } else { assert.NoError(t, err) } - if tt.wantErrOut == nil { - assert.Equal(t, "", stderr.String()) - } else { - assert.True(t, tt.wantErrOut.MatchString(stderr.String())) - } + output := strings.ReplaceAll(stderr.String(), config.ConfigDir()+string(filepath.Separator), "GH_CONFIG_DIR/") + assert.Equal(t, tt.wantOut, output) mainBuf := bytes.Buffer{} hostsBuf := bytes.Buffer{} @@ -254,8 +350,6 @@ func Test_statusRun(t *testing.T) { assert.Equal(t, "", mainBuf.String()) assert.Equal(t, "", hostsBuf.String()) - - reg.Verify(t) }) } } diff --git a/pkg/cmd/auth/token/token.go b/pkg/cmd/auth/token/token.go index 9dcc65e8d..c28d42d26 100644 --- a/pkg/cmd/auth/token/token.go +++ b/pkg/cmd/auth/token/token.go @@ -52,9 +52,8 @@ func tokenRun(opts *TokenOptions) error { return err } - key := "oauth_token" - val, err := cfg.GetOrDefault(hostname, key) - if err != nil { + val, _ := cfg.AuthToken(hostname) + if val == "" { return fmt.Errorf("no oauth token") } diff --git a/pkg/cmd/browse/browse.go b/pkg/cmd/browse/browse.go index de18ad416..524f9f20e 100644 --- a/pkg/cmd/browse/browse.go +++ b/pkg/cmd/browse/browse.go @@ -35,6 +35,7 @@ type BrowseOptions struct { Branch string CommitFlag bool ProjectsFlag bool + ReleasesFlag bool SettingsFlag bool WikiFlag bool NoBrowserFlag bool @@ -94,12 +95,13 @@ func NewCmdBrowse(f *cmdutil.Factory, runF func(*BrowseOptions) error) *cobra.Co } if err := cmdutil.MutuallyExclusive( - "specify only one of `--branch`, `--commit`, `--projects`, `--wiki`, or `--settings`", + "specify only one of `--branch`, `--commit`, `--releases`, `--projects`, `--wiki`, or `--settings`", opts.Branch != "", opts.CommitFlag, opts.WikiFlag, opts.SettingsFlag, opts.ProjectsFlag, + opts.ReleasesFlag, ); err != nil { return err } @@ -116,6 +118,7 @@ func NewCmdBrowse(f *cmdutil.Factory, runF func(*BrowseOptions) error) *cobra.Co cmdutil.EnableRepoOverride(cmd, f) cmd.Flags().BoolVarP(&opts.ProjectsFlag, "projects", "p", false, "Open repository projects") + cmd.Flags().BoolVarP(&opts.ReleasesFlag, "releases", "r", false, "Open repository releases") cmd.Flags().BoolVarP(&opts.WikiFlag, "wiki", "w", false, "Open repository wiki") cmd.Flags().BoolVarP(&opts.SettingsFlag, "settings", "s", false, "Open repository settings") cmd.Flags().BoolVarP(&opts.NoBrowserFlag, "no-browser", "n", false, "Print destination URL instead of opening the browser") @@ -160,6 +163,8 @@ func parseSection(baseRepo ghrepo.Interface, opts *BrowseOptions) (string, error if opts.SelectorArg == "" { if opts.ProjectsFlag { return "projects", nil + } else if opts.ReleasesFlag { + return "releases", nil } else if opts.SettingsFlag { return "settings", nil } else if opts.WikiFlag { diff --git a/pkg/cmd/browse/browse_test.go b/pkg/cmd/browse/browse_test.go index 75e615be0..3a283f8f4 100644 --- a/pkg/cmd/browse/browse_test.go +++ b/pkg/cmd/browse/browse_test.go @@ -47,6 +47,14 @@ func TestNewCmdBrowse(t *testing.T) { }, wantsErr: false, }, + { + name: "releases flag", + cli: "--releases", + wants: BrowseOptions{ + ReleasesFlag: true, + }, + wantsErr: false, + }, { name: "wiki flag", cli: "--wiki", @@ -141,6 +149,7 @@ func TestNewCmdBrowse(t *testing.T) { assert.Equal(t, tt.wants.Branch, opts.Branch) assert.Equal(t, tt.wants.SelectorArg, opts.SelectorArg) assert.Equal(t, tt.wants.ProjectsFlag, opts.ProjectsFlag) + assert.Equal(t, tt.wants.ReleasesFlag, opts.ReleasesFlag) assert.Equal(t, tt.wants.WikiFlag, opts.WikiFlag) assert.Equal(t, tt.wants.NoBrowserFlag, opts.NoBrowserFlag) assert.Equal(t, tt.wants.SettingsFlag, opts.SettingsFlag) @@ -190,6 +199,14 @@ func Test_runBrowse(t *testing.T) { baseRepo: ghrepo.New("ttran112", "7ate9"), expectedURL: "https://github.com/ttran112/7ate9/projects", }, + { + name: "releases flag", + opts: BrowseOptions{ + ReleasesFlag: true, + }, + baseRepo: ghrepo.New("ttran112", "7ate9"), + expectedURL: "https://github.com/ttran112/7ate9/releases", + }, { name: "wiki flag", opts: BrowseOptions{ diff --git a/pkg/cmd/codespace/code_test.go b/pkg/cmd/codespace/code_test.go index 26aa05d4c..f43d8a20c 100644 --- a/pkg/cmd/codespace/code_test.go +++ b/pkg/cmd/codespace/code_test.go @@ -97,7 +97,7 @@ func TestPendingOperationDisallowsCode(t *testing.T) { func testingCodeApp() *App { ios, _, _, _ := iostreams.Test() - return NewApp(ios, nil, testCodeApiMock(), nil) + return NewApp(ios, nil, testCodeApiMock(), nil, nil) } func testCodeApiMock() *apiClientMock { diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 21c6ec358..69600fa56 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -13,6 +13,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" + clicontext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" @@ -32,9 +33,10 @@ type App struct { errLogger *log.Logger executable executable browser browser.Browser + remotes func() (clicontext.Remotes, error) } -func NewApp(io *iostreams.IOStreams, exe executable, apiClient apiClient, browser browser.Browser) *App { +func NewApp(io *iostreams.IOStreams, exe executable, apiClient apiClient, browser browser.Browser, remotes func() (clicontext.Remotes, error)) *App { errLogger := log.New(io.ErrOut, "", 0) return &App{ @@ -43,6 +45,7 @@ func NewApp(io *iostreams.IOStreams, exe executable, apiClient apiClient, browse errLogger: errLogger, executable: exe, browser: browser, + remotes: remotes, } } @@ -84,6 +87,7 @@ func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App //go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient type apiClient interface { + GetUser(ctx context.Context) (*api.User, error) GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*api.Codespace, error) ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) diff --git a/pkg/cmd/codespace/create.go b/pkg/cmd/codespace/create.go index f8a211adb..38908ce3b 100644 --- a/pkg/cmd/codespace/create.go +++ b/pkg/cmd/codespace/create.go @@ -10,6 +10,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/spf13/cobra" @@ -119,12 +120,24 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { promptForRepoAndBranch := userInputs.Repository == "" if promptForRepoAndBranch { + var defaultRepo string + if remotes, _ := a.remotes(); remotes != nil { + if defaultRemote, _ := remotes.ResolvedRemote(); defaultRemote != nil { + // this is a remote explicitly chosen via `repo set-default` + defaultRepo = ghrepo.FullName(defaultRemote) + } else if len(remotes) > 0 { + // as a fallback, just pick the first remote + defaultRepo = ghrepo.FullName(remotes[0]) + } + } + repoQuestions := []*survey.Question{ { Name: "repository", Prompt: &survey.Input{ Message: "Repository:", Help: "Search for repos by name. To search within an org or user, or to see private repos, enter at least ':user/'.", + Default: defaultRepo, Suggest: func(toComplete string) []string { return getRepoSuggestions(ctx, a.apiClient, toComplete) }, @@ -157,7 +170,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { }) if err != nil { return fmt.Errorf("error checking codespace ownership: %w", err) - } else if billableOwner != nil && billableOwner.Type == "Organization" { + } else if billableOwner != nil && (billableOwner.Type == "Organization" || billableOwner.Type == "User") { cs := a.io.ColorScheme() fmt.Fprintln(a.io.ErrOut, cs.Blue(" ✓ Codespaces usage for this repository is paid for by "+billableOwner.Login)) } diff --git a/pkg/cmd/codespace/create_test.go b/pkg/cmd/codespace/create_test.go index f8f27b79f..4f084ce12 100644 --- a/pkg/cmd/codespace/create_test.go +++ b/pkg/cmd/codespace/create_test.go @@ -57,6 +57,7 @@ func TestApp_Create(t *testing.T) { retentionPeriod: NullableDuration{durationPtr(48 * time.Hour)}, }, wantStdout: "monalisa-dotfiles-abcd1234\n", + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", }, { name: "create with explicit display name", @@ -78,6 +79,7 @@ func TestApp_Create(t *testing.T) { displayName: "funky flute", }, wantStdout: "monalisa-dotfiles-abcd1234\n", + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", }, { name: "create codespace with default branch shows idle timeout notice if present", @@ -111,6 +113,7 @@ func TestApp_Create(t *testing.T) { devContainerPath: ".devcontainer/foobar/devcontainer.json", }, wantStdout: "monalisa-dotfiles-abcd1234\n", + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", }, { name: "create codespace with devcontainer path results in selecting the correct machine type", @@ -172,6 +175,7 @@ func TestApp_Create(t *testing.T) { devContainerPath: ".devcontainer/foobar/devcontainer.json", }, wantStdout: "monalisa-dotfiles-abcd1234\n", + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", }, { name: "create codespace with default branch with default devcontainer if no path provided and no devcontainer files exist in the repo", @@ -205,7 +209,7 @@ func TestApp_Create(t *testing.T) { idleTimeout: 30 * time.Minute, }, wantStdout: "monalisa-dotfiles-abcd1234\n", - wantStderr: "Notice: Idle timeout for this codespace is set to 10 minutes in compliance with your organization's policy\n", + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\nNotice: Idle timeout for this codespace is set to 10 minutes in compliance with your organization's policy\n", isTTY: true, }, { @@ -224,7 +228,8 @@ func TestApp_Create(t *testing.T) { showStatus: false, idleTimeout: 30 * time.Minute, }, - wantErr: fmt.Errorf("error getting devcontainer.json paths: some error"), + wantErr: fmt.Errorf("error getting devcontainer.json paths: some error"), + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", }, { name: "create codespace with default branch does not show idle timeout notice if not conntected to terminal", @@ -252,7 +257,7 @@ func TestApp_Create(t *testing.T) { idleTimeout: 30 * time.Minute, }, wantStdout: "monalisa-dotfiles-abcd1234\n", - wantStderr: "", + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", isTTY: false, }, { @@ -280,7 +285,8 @@ func TestApp_Create(t *testing.T) { idleTimeout: 30 * time.Minute, }, wantErr: cmdutil.SilentError, - wantStderr: `You must authorize or deny additional permissions requested by this codespace before continuing. + wantStderr: ` ✓ Codespaces usage for this repository is paid for by monalisa +You must authorize or deny additional permissions requested by this codespace before continuing. Open this URL in your browser to review and authorize additional permissions: example.com/permissions Alternatively, you can run "create" with the "--default-permissions" option to continue without authorizing additional permissions. `, @@ -304,7 +310,31 @@ Alternatively, you can run "create" with the "--default-permissions" option to c wantErr: fmt.Errorf("error checking codespace ownership: some error"), }, { - name: "mentions billable owner when org covers codepaces for a repository", + name: "mentions User as billable owner when org does not cover codepaces for a repository", + fields: fields{ + apiClient: apiCreateDefaults(&apiClientMock{ + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Type: "User", + Login: "monalisa", + }, nil + }, + CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { + return &api.Codespace{ + Name: "monalisa-dotfiles-abcd1234", + }, nil + }, + }), + }, + opts: createOptions{ + repo: "monalisa/dotfiles", + branch: "main", + }, + wantStderr: " ✓ Codespaces usage for this repository is paid for by monalisa\n", + wantStdout: "monalisa-dotfiles-abcd1234\n", + }, + { + name: "mentions Organization as billable owner when org covers codepaces for a repository", fields: fields{ apiClient: apiCreateDefaults(&apiClientMock{ GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { @@ -330,6 +360,28 @@ Alternatively, you can run "create" with the "--default-permissions" option to c wantStderr: " ✓ Codespaces usage for this repository is paid for by megacorp\n", wantStdout: "megacorp-private-abcd1234\n", }, + { + name: "does not mention billable owner when not an expected type", + fields: fields{ + apiClient: apiCreateDefaults(&apiClientMock{ + GetCodespaceBillableOwnerFunc: func(ctx context.Context, nwo string) (*api.User, error) { + return &api.User{ + Type: "UnexpectedBillableOwnerType", + Login: "mega-owner", + }, nil + }, + CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { + return &api.Codespace{ + Name: "megacorp-private-abcd1234", + }, nil + }, + }), + }, + opts: createOptions{ + repo: "megacorp/private", + }, + wantStdout: "megacorp-private-abcd1234\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/cmd/codespace/delete.go b/pkg/cmd/codespace/delete.go index 10f01cd95..74a1dd298 100644 --- a/pkg/cmd/codespace/delete.go +++ b/pkg/cmd/codespace/delete.go @@ -83,11 +83,17 @@ func (a *App) Delete(ctx context.Context, opts deleteOptions) (err error) { var codespaces []*api.Codespace nameFilter := opts.codespaceName if nameFilter == "" { - var codespaces []*api.Codespace - err = a.RunWithProgress("Fetching codespaces", func() (err error) { - codespaces, err = a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{OrgName: opts.orgName, UserName: opts.userName}) - return - }) + a.StartProgressIndicatorWithLabel("Fetching codespaces") + userName := opts.userName + if userName == "" && opts.orgName != "" { + currentUser, err := a.apiClient.GetUser(ctx) + if err != nil { + return err + } + userName = currentUser.Login + } + codespaces, err = a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{OrgName: opts.orgName, UserName: userName}) + a.StopProgressIndicator() if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } diff --git a/pkg/cmd/codespace/delete_test.go b/pkg/cmd/codespace/delete_test.go index 93f89c775..ca6fe989e 100644 --- a/pkg/cmd/codespace/delete_test.go +++ b/pkg/cmd/codespace/delete_test.go @@ -202,10 +202,28 @@ func TestDelete(t *testing.T) { wantStdout: "", wantErr: true, }, + { + name: "deletion for org codespace succeeds without username", + opts: deleteOptions{ + deleteAll: true, + orgName: "bookish", + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + Owner: api.User{Login: "monalisa"}, + }, + }, + wantDeleted: []string{"monalisa-spoonknife-123"}, + wantStdout: "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { apiMock := &apiClientMock{ + GetUserFunc: func(_ context.Context) (*api.User, error) { + return &api.User{Login: "monalisa"}, nil + }, DeleteCodespaceFunc: func(_ context.Context, name string, orgName string, userName string) error { if tt.deleteErr != nil { return tt.deleteErr @@ -248,11 +266,16 @@ func TestDelete(t *testing.T) { ios, _, stdout, stderr := iostreams.Test() ios.SetStdinTTY(true) ios.SetStdoutTTY(true) - app := NewApp(ios, nil, apiMock, nil) + app := NewApp(ios, nil, apiMock, nil, nil) err := app.Delete(context.Background(), opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } + for _, listArgs := range apiMock.ListCodespacesCalls() { + if listArgs.Opts.OrgName != "" && listArgs.Opts.UserName == "" { + t.Errorf("ListCodespaces() expected username option to be set") + } + } var gotDeleted []string for _, delArgs := range apiMock.DeleteCodespaceCalls() { gotDeleted = append(gotDeleted, delArgs.Name) diff --git a/pkg/cmd/codespace/edit_test.go b/pkg/cmd/codespace/edit_test.go index b5e0e3cd1..886d9e455 100644 --- a/pkg/cmd/codespace/edit_test.go +++ b/pkg/cmd/codespace/edit_test.go @@ -88,7 +88,7 @@ func TestEdit(t *testing.T) { } ios, _, stdout, stderr := iostreams.Test() - a := NewApp(ios, nil, apiMock, nil) + a := NewApp(ios, nil, apiMock, nil, nil) var err error if tt.cliArgs == nil { diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index d7544b9a9..f8cde0b67 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -6,6 +6,7 @@ import ( "net" "strings" + "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" @@ -60,7 +61,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { } // Pass 0 to pick a random port - listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) + listen, _, err := codespaces.ListenTCP(0) if err != nil { return err } diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 13c59cf72..d2edcfe81 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -3,7 +3,6 @@ package codespace import ( "context" "fmt" - "net" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/rpc" @@ -49,12 +48,11 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err defer safeClose(session, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port + listen, localPort, err := codespaces.ListenTCP(0) if err != nil { return err } defer listen.Close() - localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser := 0, "" err = a.RunWithProgress("Fetching SSH Details", func() error { diff --git a/pkg/cmd/codespace/logs_test.go b/pkg/cmd/codespace/logs_test.go index bd4ea02f8..161657b4d 100644 --- a/pkg/cmd/codespace/logs_test.go +++ b/pkg/cmd/codespace/logs_test.go @@ -36,5 +36,5 @@ func testingLogsApp() *App { } ios, _, _, _ := iostreams.Test() - return NewApp(ios, nil, apiMock, nil) + return NewApp(ios, nil, apiMock, nil, nil) } diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index 772b87f56..0796679fc 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -46,6 +46,9 @@ import ( // GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { // panic("mock out the GetRepository method") // }, +// GetUserFunc: func(ctx context.Context) (*api.User, error) { +// panic("mock out the GetUser method") +// }, // ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, @@ -95,6 +98,9 @@ type apiClientMock struct { // GetRepositoryFunc mocks the GetRepository method. GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) + // GetUserFunc mocks the GetUser method. + GetUserFunc func(ctx context.Context) (*api.User, error) + // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) @@ -201,6 +207,11 @@ type apiClientMock struct { // Nwo is the nwo argument value. Nwo string } + // GetUser holds details about calls to the GetUser method. + GetUser []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } // ListCodespaces holds details about calls to the ListCodespaces method. ListCodespaces []struct { // Ctx is the ctx argument value. @@ -248,6 +259,7 @@ type apiClientMock struct { lockGetCodespacesMachines sync.RWMutex lockGetOrgMemberCodespace sync.RWMutex lockGetRepository sync.RWMutex + lockGetUser sync.RWMutex lockListCodespaces sync.RWMutex lockListDevContainers sync.RWMutex lockStartCodespace sync.RWMutex @@ -658,6 +670,38 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct { return calls } +// GetUser calls GetUserFunc. +func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { + if mock.GetUserFunc == nil { + panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetUser.Lock() + mock.calls.GetUser = append(mock.calls.GetUser, callInfo) + mock.lockGetUser.Unlock() + return mock.GetUserFunc(ctx) +} + +// GetUserCalls gets all the calls that were made to GetUser. +// Check the length with: +// +// len(mockedapiClient.GetUserCalls()) +func (mock *apiClientMock) GetUserCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetUser.RLock() + calls = mock.calls.GetUser + mock.lockGetUser.RUnlock() + return calls +} + // ListCodespaces calls ListCodespacesFunc. func (mock *apiClientMock) ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { if mock.ListCodespacesFunc == nil { diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 65811da45..dbc5f31aa 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "net" "net/http" "strconv" "strings" @@ -392,7 +391,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st for _, pair := range portPairs { pair := pair group.Go(func() error { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) + listen, _, err := codespaces.ListenTCP(pair.local) if err != nil { return err } diff --git a/pkg/cmd/codespace/ports_test.go b/pkg/cmd/codespace/ports_test.go index 3d3c87d95..ea61b11a5 100644 --- a/pkg/cmd/codespace/ports_test.go +++ b/pkg/cmd/codespace/ports_test.go @@ -263,5 +263,5 @@ func testingPortsApp() *App { ios, _, _, _ := iostreams.Test() - return NewApp(ios, nil, apiMock, nil) + return NewApp(ios, nil, apiMock, nil, nil) } diff --git a/pkg/cmd/codespace/rebuild_test.go b/pkg/cmd/codespace/rebuild_test.go index ec8b112de..f2496d089 100644 --- a/pkg/cmd/codespace/rebuild_test.go +++ b/pkg/cmd/codespace/rebuild_test.go @@ -32,5 +32,5 @@ func testingRebuildApp(mockCodespace api.Codespace) *App { } ios, _, _, _ := iostreams.Test() - return NewApp(ios, nil, apiMock, nil) + return NewApp(ios, nil, apiMock, nil, nil) } diff --git a/pkg/cmd/codespace/select_test.go b/pkg/cmd/codespace/select_test.go index c21a876fb..e97a7720e 100644 --- a/pkg/cmd/codespace/select_test.go +++ b/pkg/cmd/codespace/select_test.go @@ -47,7 +47,7 @@ func TestApp_Select(t *testing.T) { ios, _, stdout, stderr := iostreams.Test() ios.SetStdinTTY(true) ios.SetStdoutTTY(true) - a := NewApp(ios, nil, testSelectApiMock(), nil) + a := NewApp(ios, nil, testSelectApiMock(), nil, nil) opts := selectOptions{} if tt.outputToFile { diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 49ebe2c1c..aa410cd1a 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -6,9 +6,7 @@ import ( "context" "errors" "fmt" - "io" "log" - "net" "os" "os/exec" "path" @@ -190,7 +188,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e if opts.stdio { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true) - stdio := newReadWriteCloser(os.Stdin, os.Stdout) + stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout) err := fwd.Forward(ctx, stdio) // always non-nil return fmt.Errorf("tunnel closed: %w", err) } @@ -201,12 +199,11 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e // Ensure local port is listening before client (Shell) connects. // Unless the user specifies a server port, localSSHServerPort is 0 // and thus the client will pick a random port. - listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localSSHServerPort)) + listen, localSSHServerPort, err := codespaces.ListenTCP(localSSHServerPort) if err != nil { return err } defer listen.Close() - localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := opts.profile if connectDestination == "" { @@ -748,21 +745,3 @@ func (fl *fileLogger) Name() string { func (fl *fileLogger) Close() error { return fl.f.Close() } - -type combinedReadWriteCloser struct { - io.ReadCloser - io.WriteCloser -} - -func newReadWriteCloser(reader io.ReadCloser, writer io.WriteCloser) io.ReadWriteCloser { - return &combinedReadWriteCloser{reader, writer} -} - -func (crwc *combinedReadWriteCloser) Close() error { - werr := crwc.WriteCloser.Close() - rerr := crwc.ReadCloser.Close() - if werr != nil { - return werr - } - return rerr -} diff --git a/pkg/cmd/codespace/ssh_test.go b/pkg/cmd/codespace/ssh_test.go index 2c59539da..1740b00f7 100644 --- a/pkg/cmd/codespace/ssh_test.go +++ b/pkg/cmd/codespace/ssh_test.go @@ -278,5 +278,5 @@ func testingSSHApp() *App { } ios, _, _, _ := iostreams.Test() - return NewApp(ios, nil, apiMock, nil) + return NewApp(ios, nil, apiMock, nil, nil) } diff --git a/pkg/cmd/extension/browse/browse.go b/pkg/cmd/extension/browse/browse.go index be605ef9c..247d43c5b 100644 --- a/pkg/cmd/extension/browse/browse.go +++ b/pkg/cmd/extension/browse/browse.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/MakeNowJust/heredoc" "github.com/charmbracelet/glamour" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/config" @@ -25,16 +26,17 @@ import ( const pagingOffset = 24 type ExtBrowseOpts struct { - Cmd *cobra.Command - Browser ibrowser - IO *iostreams.IOStreams - Searcher search.Searcher - Em extensions.ExtensionManager - Client *http.Client - Logger *log.Logger - Cfg config.Config - Rg *readmeGetter - Debug bool + Cmd *cobra.Command + Browser ibrowser + IO *iostreams.IOStreams + Searcher search.Searcher + Em extensions.ExtensionManager + Client *http.Client + Logger *log.Logger + Cfg config.Config + Rg *readmeGetter + Debug bool + SingleColumn bool } type ibrowser interface { @@ -48,7 +50,8 @@ type uiRegistry struct { App *tview.Application Outerflex *tview.Flex List *tview.List - Readme *tview.TextView + Pages *tview.Pages + CmdFlex *tview.Flex } type extEntry struct { @@ -83,25 +86,44 @@ func (e extEntry) Description() string { } type extList struct { - ui uiRegistry - extEntries []extEntry - app *tview.Application - filter string - opts ExtBrowseOpts + ui uiRegistry + extEntries []extEntry + app *tview.Application + filter string + opts ExtBrowseOpts + QueueUpdateDraw func(func()) *tview.Application + WaitGroup wGroup } +type wGroup interface { + Add(int) + Done() + Wait() +} + +type fakeGroup struct{} + +func (w *fakeGroup) Add(int) {} +func (w *fakeGroup) Done() {} +func (w *fakeGroup) Wait() {} + func newExtList(opts ExtBrowseOpts, ui uiRegistry, extEntries []extEntry) *extList { ui.List.SetTitleColor(tcell.ColorWhite) ui.List.SetSelectedTextColor(tcell.ColorBlack) ui.List.SetSelectedBackgroundColor(tcell.ColorWhite) ui.List.SetWrapAround(false) ui.List.SetBorderPadding(1, 1, 1, 1) + ui.List.SetSelectedFunc(func(ix int, _, _ string, _ rune) { + ui.Pages.SwitchToPage("readme") + }) el := &extList{ - ui: ui, - extEntries: extEntries, - app: ui.App, - opts: opts, + ui: ui, + extEntries: extEntries, + app: ui.App, + opts: opts, + QueueUpdateDraw: ui.App.QueueUpdateDraw, + WaitGroup: &fakeGroup{}, } el.Reset() @@ -112,66 +134,97 @@ func (el *extList) createModal() *tview.Modal { m := tview.NewModal() m.SetBackgroundColor(tcell.ColorPurple) m.SetDoneFunc(func(_ int, _ string) { - el.ui.App.SetRoot(el.ui.Outerflex, true) + el.ui.Pages.SwitchToPage("main") el.Refresh() }) return m } -func (el *extList) InstallSelected() { +func (el *extList) toggleSelected(verb string) { ee, ix := el.FindSelected() if ix < 0 { el.opts.Logger.Println("failed to find selected entry") return } - repo, err := ghrepo.FromFullName(ee.FullName) - if err != nil { - el.opts.Logger.Println(fmt.Errorf("failed to install '%s't: %w", ee.FullName, err)) + modal := el.createModal() + + if (ee.Installed && verb == "install") || (!ee.Installed && verb == "remove") { return } - modal := el.createModal() + var action func() error - modal.SetText(fmt.Sprintf("Installing %s...", ee.FullName)) - el.ui.App.SetRoot(modal, true) - // I could eliminate this with a goroutine but it seems to be working fine - el.app.ForceDraw() - err = el.opts.Em.Install(repo, "") - if err != nil { - modal.SetText(fmt.Sprintf("Failed to install %s: %s", ee.FullName, err.Error())) + if !ee.Installed { + modal.SetText(fmt.Sprintf("Installing %s...", ee.FullName)) + action = func() error { + repo, err := ghrepo.FromFullName(ee.FullName) + if err != nil { + el.opts.Logger.Println(fmt.Errorf("failed to install '%s': %w", ee.FullName, err)) + return err + } + err = el.opts.Em.Install(repo, "") + if err != nil { + return fmt.Errorf("failed to install %s: %w", ee.FullName, err) + } + return nil + } } else { - modal.SetText(fmt.Sprintf("Installed %s!", ee.FullName)) - modal.AddButtons([]string{"ok"}) - el.ui.App.SetFocus(modal) + modal.SetText(fmt.Sprintf("Removing %s...", ee.FullName)) + action = func() error { + name := strings.TrimPrefix(ee.Name, "gh-") + err := el.opts.Em.Remove(name) + if err != nil { + return fmt.Errorf("failed to remove %s: %w", ee.FullName, err) + } + return nil + } } - el.toggleInstalled(ix) + el.ui.CmdFlex.Clear() + el.ui.CmdFlex.AddItem(modal, 0, 1, true) + var err error + wg := el.WaitGroup + wg.Add(1) + + go func() { + el.QueueUpdateDraw(func() { + el.ui.Pages.SwitchToPage("command") + wg.Add(1) + wg.Done() + go func() { + el.QueueUpdateDraw(func() { + err = action() + if err != nil { + modal.SetText(err.Error()) + } else { + modalText := fmt.Sprintf("Installed %s!", ee.FullName) + if verb == "remove" { + modalText = fmt.Sprintf("Removed %s!", ee.FullName) + } + modal.SetText(modalText) + modal.AddButtons([]string{"ok"}) + el.app.SetFocus(modal) + } + wg.Done() + }) + }() + }) + }() + + // TODO blocking the app's thread and deadlocking + wg.Wait() + if err == nil { + el.toggleInstalled(ix) + } +} + +func (el *extList) InstallSelected() { + el.toggleSelected("install") } func (el *extList) RemoveSelected() { - ee, ix := el.FindSelected() - if ix < 0 { - el.opts.Logger.Println("failed to find selected extension") - return - } - - modal := el.createModal() - - modal.SetText(fmt.Sprintf("Removing %s...", ee.FullName)) - el.ui.App.SetRoot(modal, true) - // I could eliminate this with a goroutine but it seems to be working fine - el.ui.App.ForceDraw() - - err := el.opts.Em.Remove(strings.TrimPrefix(ee.Name, "gh-")) - if err != nil { - modal.SetText(fmt.Sprintf("Failed to remove %s: %s", ee.FullName, err.Error())) - } else { - modal.SetText(fmt.Sprintf("Removed %s.", ee.FullName)) - modal.AddButtons([]string{"ok"}) - el.ui.App.SetFocus(modal) - } - el.toggleInstalled(ix) + el.toggleSelected("remove") } func (el *extList) toggleInstalled(ix int) { @@ -365,14 +418,19 @@ func ExtBrowse(opts ExtBrowseOpts) error { readme.SetBorder(true).SetBorderColor(tcell.ColorPurple) help := tview.NewTextView() - help.SetText( - "/: filter i/r: install/remove w: open in browser pgup/pgdn: scroll readme q: quit") - help.SetTextAlign(tview.AlignCenter) + help.SetDynamicColors(true) + help.SetText("[::b]?[-:-:-]: help [::b]j/k[-:-:-]: move [::b]i[-:-:-]: install [::b]r[-:-:-]: remove [::b]w[-:-:-]: web [::b]↵[-:-:-]: view readme [::b]q[-:-:-]: quit") + + cmdFlex := tview.NewFlex() + + pages := tview.NewPages() ui := uiRegistry{ App: app, Outerflex: outerFlex, List: list, + Pages: pages, + CmdFlex: cmdFlex, } extList := newExtList(opts, ui, extEntries) @@ -414,7 +472,9 @@ func ExtBrowse(opts ExtBrowseOpts) error { innerFlex.SetDirection(tview.FlexColumn) innerFlex.AddItem(list, 0, 1, true) - innerFlex.AddItem(readme, 0, 1, false) + if !opts.SingleColumn { + innerFlex.AddItem(readme, 0, 1, false) + } outerFlex.SetDirection(tview.FlexRow) outerFlex.AddItem(header, 1, -1, false) @@ -422,7 +482,50 @@ func ExtBrowse(opts ExtBrowseOpts) error { outerFlex.AddItem(innerFlex, 0, 1, true) outerFlex.AddItem(help, 1, -1, false) - app.SetRoot(outerFlex, true) + helpBig := tview.NewTextView() + helpBig.SetDynamicColors(true) + helpBig.SetBorderPadding(0, 0, 2, 0) + helpBig.SetText(heredoc.Doc(` + [::b]Application[-:-:-] + + ?: toggle help + q: quit + + [::b]Navigation[-:-:-] + + ↓, j: scroll list of extensions down by 1 + ↑, k: scroll list of extensions up by 1 + + shift+j, space: scroll list of extensions down by 25 + shift+k, ctrl+space (mac), shift+space (windows): scroll list of extensions up by 25 + + [::b]Extension Management[-:-:-] + + i: install highlighted extension + r: remove highlighted extension + w: open highlighted extension in web browser + + [::b]Filtering[-:-:-] + + /: focus filter + enter: finish filtering and go back to list + escape: clear filter and reset list + + [::b]Readmes[-:-:-] + + enter: open highlighted extension's readme full screen + page down: scroll readme pane down + page up: scroll readme pane up + + (On a mac, page down and page up are fn+down arrow and fn+up arrow) + `)) + + pages.AddPage("main", outerFlex, true, true) + pages.AddPage("help", helpBig, true, false) + pages.AddPage("readme", readme, true, false) + pages.AddPage("command", cmdFlex, true, false) + + app.SetRoot(pages, true) // Force fetching of initial readme by loading it just prior to the first // draw. The callback is removed immediately after draw. @@ -441,7 +544,41 @@ func ExtBrowse(opts ExtBrowseOpts) error { return event } + curPage, _ := pages.GetFrontPage() + + if curPage != "main" { + if curPage == "command" { + return event + } + if event.Rune() == 'q' || event.Key() == tcell.KeyEscape { + pages.SwitchToPage("main") + return nil + } + switch curPage { + case "readme": + switch event.Key() { + case tcell.KeyPgUp: + row, col := readme.GetScrollOffset() + if row > 0 { + readme.ScrollTo(row-2, col) + } + case tcell.KeyPgDn: + row, col := readme.GetScrollOffset() + readme.ScrollTo(row+2, col) + } + case "help": + switch event.Rune() { + case '?': + pages.SwitchToPage("main") + } + } + return nil + } + switch event.Rune() { + case '?': + pages.SwitchToPage("help") + return nil case 'q': app.Stop() case 'k': @@ -491,7 +628,7 @@ func ExtBrowse(opts ExtBrowseOpts) error { filter.SetText("") extList.Reset() case tcell.KeyCtrlSpace: - // The ctrl check works on windows/mac and not windows: + // The ctrl check works on linux/mac and not windows: extList.PageUp() go loadSelectedReadme() case tcell.KeyCtrlJ: @@ -500,25 +637,11 @@ func ExtBrowse(opts ExtBrowseOpts) error { case tcell.KeyCtrlK: extList.PageUp() go loadSelectedReadme() - case tcell.KeyPgUp: - row, col := readme.GetScrollOffset() - if row > 0 { - readme.ScrollTo(row-2, col) - } - return nil - case tcell.KeyPgDn: - row, col := readme.GetScrollOffset() - readme.ScrollTo(row+2, col) - return nil } return event }) - // Without this redirection, the git client inside of the extension manager - // will dump git output to the terminal. - opts.IO.ErrOut = io.Discard - if err := app.Run(); err != nil { return err } diff --git a/pkg/cmd/extension/browse/browse_test.go b/pkg/cmd/extension/browse/browse_test.go index 8cec46607..aacfb37b5 100644 --- a/pkg/cmd/extension/browse/browse_test.go +++ b/pkg/cmd/extension/browse/browse_test.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "net/url" + "sync" "testing" "time" @@ -274,11 +275,15 @@ func Test_extList(t *testing.T) { }, }, } + cmdFlex := tview.NewFlex() app := tview.NewApplication() list := tview.NewList() + pages := tview.NewPages() ui := uiRegistry{ - List: list, - App: app, + List: list, + App: app, + CmdFlex: cmdFlex, + Pages: pages, } extEntries := []extEntry{ { @@ -313,6 +318,13 @@ func Test_extList(t *testing.T) { extList := newExtList(opts, ui, extEntries) + extList.QueueUpdateDraw = func(f func()) *tview.Application { + f() + return app + } + + extList.WaitGroup = &sync.WaitGroup{} + extList.Filter("cool") assert.Equal(t, 1, extList.ui.List.GetItemCount()) @@ -322,6 +334,8 @@ func Test_extList(t *testing.T) { extList.InstallSelected() assert.True(t, extList.extEntries[0].Installed) + // so I think the goroutines are causing a later failure because the toggleInstalled isn't seen. + extList.Refresh() assert.Equal(t, 1, extList.ui.List.GetItemCount()) diff --git a/pkg/cmd/extension/command.go b/pkg/cmd/extension/command.go index 3bfd8cfe8..722163dd2 100644 --- a/pkg/cmd/extension/command.go +++ b/pkg/cmd/extension/command.go @@ -3,6 +3,7 @@ package extension import ( "errors" "fmt" + gio "io" "os" "strings" "time" @@ -24,6 +25,7 @@ import ( func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { m := f.ExtensionManager io := f.IOStreams + gc := f.GitClient prompter := f.Prompter config := f.Config browser := f.Browser @@ -410,33 +412,25 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { }, func() *cobra.Command { var debug bool + var singleColumn bool cmd := &cobra.Command{ Use: "browse", Short: "Enter a UI for browsing, adding, and removing extensions", Long: heredoc.Doc(` This command will take over your terminal and run a fully interactive - interface for browsing, adding, and removing gh extensions. + interface for browsing, adding, and removing gh extensions. A terminal + width greater than 100 columns is recommended. - The extension list is navigated with the arrow keys or with j/k. - Space and control+space (or control + j/k) page the list up and down. - Extension readmes can be scrolled with page up/page down keys - (fn + arrow up/down on a mac keyboard). - - For highlighted extensions, you can press: - - - w to open the extension in your web browser - - i to install the extension - - r to remove the extension - - Press / to focus the filter input. Press enter to scroll the results. - Press Escape to clear the filter and return to the full list. + To learn how to control this interface, press ? after running to see + the help text. Press q to quit. - The output of this command may be difficult to navigate for screen reader - users, users operating at high zoom and other users of assistive technology. It - is also not advised for automation scripts. We advise those users to use the - alternative command: + Running this command with --single-column should make this command + more intelligible for users who rely on assistive technology like screen + readers or high zoom. + + For a more traditional way to discover extensions, see: gh ext search @@ -459,21 +453,25 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { searcher := search.NewSearcher(api.NewCachedHTTPClient(client, time.Hour*24), host) + gc.Stderr = gio.Discard + opts := browse.ExtBrowseOpts{ - Cmd: cmd, - IO: io, - Browser: browser, - Searcher: searcher, - Em: m, - Client: client, - Cfg: cfg, - Debug: debug, + Cmd: cmd, + IO: io, + Browser: browser, + Searcher: searcher, + Em: m, + Client: client, + Cfg: cfg, + Debug: debug, + SingleColumn: singleColumn, } return browse.ExtBrowse(opts) }, } cmd.Flags().BoolVar(&debug, "debug", false, "log to /tmp/extBrowse-*") + cmd.Flags().BoolVarP(&singleColumn, "single-column", "s", false, "Render TUI with only one column of text") return cmd }(), &cobra.Command{ @@ -564,9 +562,18 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { } else { fullName = "gh-" + extName } + + cs := io.ColorScheme() + + commitIcon := cs.SuccessIcon() if err := m.Create(fullName, tmplType); err != nil { - return err + if errors.Is(err, ErrInitialCommitFailed) { + commitIcon = cs.FailureIcon() + } else { + return err + } } + if !io.IsStdoutTTY() { return nil } @@ -577,7 +584,6 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { "- run 'cd %[1]s; gh extension install .; gh %[2]s' to see your new extension in action", fullName, extName) - cs := io.ColorScheme() if tmplType == extensions.GoBinTemplateType { goBinChecks = heredoc.Docf(` %[1]s Downloaded Go dependencies @@ -585,7 +591,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { `, cs.SuccessIcon(), fullName) steps = heredoc.Docf(` - run 'cd %[1]s; gh extension install .; gh %[2]s' to see your new extension in action - - use 'go build && gh %[2]s' to see changes in your code as you develop`, fullName, extName) + - run 'go build && gh %[2]s' to see changes in your code as you develop`, fullName, extName) } else if tmplType == extensions.OtherBinTemplateType { steps = heredoc.Docf(` - run 'cd %[1]s; gh extension install .' to install your extension locally @@ -596,17 +602,18 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { out := heredoc.Docf(` %[1]s Created directory %[2]s %[1]s Initialized git repository + %[7]s Made initial commit %[1]s Set up extension scaffolding %[6]s %[2]s is ready for development! %[4]s %[5]s - - commit and use 'gh repo create' to share your extension with others + - run 'gh repo create' to share your extension with others For more information on writing extensions: %[3]s - `, cs.SuccessIcon(), fullName, link, cs.Bold("Next Steps"), steps, goBinChecks) + `, cs.SuccessIcon(), fullName, link, cs.Bold("Next Steps"), steps, goBinChecks, commitIcon) fmt.Fprint(io.Out, out) return nil }, diff --git a/pkg/cmd/extension/command_test.go b/pkg/cmd/extension/command_test.go index d7025ebf1..050b6d6f0 100644 --- a/pkg/cmd/extension/command_test.go +++ b/pkg/cmd/extension/command_test.go @@ -605,13 +605,14 @@ func TestNewCmdExtension(t *testing.T) { wantStdout: heredoc.Doc(` ✓ Created directory gh-test ✓ Initialized git repository + ✓ Made initial commit ✓ Set up extension scaffolding gh-test is ready for development! Next Steps - run 'cd gh-test; gh extension install .; gh test' to see your new extension in action - - commit and use 'gh repo create' to share your extension with others + - run 'gh repo create' to share your extension with others For more information on writing extensions: https://docs.github.com/github-cli/github-cli/creating-github-cli-extensions @@ -634,6 +635,7 @@ func TestNewCmdExtension(t *testing.T) { wantStdout: heredoc.Doc(` ✓ Created directory gh-test ✓ Initialized git repository + ✓ Made initial commit ✓ Set up extension scaffolding ✓ Downloaded Go dependencies ✓ Built gh-test binary @@ -642,8 +644,8 @@ func TestNewCmdExtension(t *testing.T) { Next Steps - run 'cd gh-test; gh extension install .; gh test' to see your new extension in action - - use 'go build && gh test' to see changes in your code as you develop - - commit and use 'gh repo create' to share your extension with others + - run 'go build && gh test' to see changes in your code as you develop + - run 'gh repo create' to share your extension with others For more information on writing extensions: https://docs.github.com/github-cli/github-cli/creating-github-cli-extensions @@ -666,6 +668,7 @@ func TestNewCmdExtension(t *testing.T) { wantStdout: heredoc.Doc(` ✓ Created directory gh-test ✓ Initialized git repository + ✓ Made initial commit ✓ Set up extension scaffolding gh-test is ready for development! @@ -674,7 +677,7 @@ func TestNewCmdExtension(t *testing.T) { - run 'cd gh-test; gh extension install .' to install your extension locally - fill in script/build.sh with your compilation script for automated builds - compile a gh-test binary locally and run 'gh test' to see changes - - commit and use 'gh repo create' to share your extension with others + - run 'gh repo create' to share your extension with others For more information on writing extensions: https://docs.github.com/github-cli/github-cli/creating-github-cli-extensions @@ -697,13 +700,44 @@ func TestNewCmdExtension(t *testing.T) { wantStdout: heredoc.Doc(` ✓ Created directory gh-test ✓ Initialized git repository + ✓ Made initial commit ✓ Set up extension scaffolding gh-test is ready for development! Next Steps - run 'cd gh-test; gh extension install .; gh test' to see your new extension in action - - commit and use 'gh repo create' to share your extension with others + - run 'gh repo create' to share your extension with others + + For more information on writing extensions: + https://docs.github.com/github-cli/github-cli/creating-github-cli-extensions + `), + }, + { + name: "create extension tty with argument commit fails", + args: []string{"create", "test"}, + managerStubs: func(em *extensions.ExtensionManagerMock) func(*testing.T) { + em.CreateFunc = func(name string, tmplType extensions.ExtTemplateType) error { + return ErrInitialCommitFailed + } + return func(t *testing.T) { + calls := em.CreateCalls() + assert.Equal(t, 1, len(calls)) + assert.Equal(t, "gh-test", calls[0].Name) + } + }, + isTTY: true, + wantStdout: heredoc.Doc(` + ✓ Created directory gh-test + ✓ Initialized git repository + X Made initial commit + ✓ Set up extension scaffolding + + gh-test is ready for development! + + Next Steps + - run 'cd gh-test; gh extension install .; gh test' to see your new extension in action + - run 'gh repo create' to share your extension with others For more information on writing extensions: https://docs.github.com/github-cli/github-cli/creating-github-cli-extensions diff --git a/pkg/cmd/extension/manager.go b/pkg/cmd/extension/manager.go index ac460bab9..4f0e6a9fb 100644 --- a/pkg/cmd/extension/manager.go +++ b/pkg/cmd/extension/manager.go @@ -27,6 +27,9 @@ import ( "gopkg.in/yaml.v3" ) +// ErrInitialCommitFailed indicates the initial commit when making a new extension failed. +var ErrInitialCommitFailed = errors.New("initial commit failed") + type Manager struct { dataDir func() string lookPath func(string) (string, error) @@ -347,7 +350,7 @@ func (m *Manager) Install(repo ghrepo.Interface, target string) error { return errors.New("extension is not installable: missing executable") } - return m.installGit(repo, target, m.io.Out, m.io.ErrOut) + return m.installGit(repo, target) } func (m *Manager) installBin(repo ghrepo.Interface, target string) error { @@ -450,7 +453,7 @@ func (m *Manager) installBin(repo ghrepo.Interface, target string) error { return nil } -func (m *Manager) installGit(repo ghrepo.Interface, target string, stdout, stderr io.Writer) error { +func (m *Manager) installGit(repo ghrepo.Interface, target string) error { protocol, _ := m.config.GetOrDefault(repo.RepoHost(), "git_protocol") cloneURL := ghrepo.FormatRemoteURL(repo, protocol) @@ -654,8 +657,15 @@ func (m *Manager) Create(name string, tmplType extensions.ExtTemplateType) error } scopedClient := m.gitClient.ForRepo(name) - _, err := scopedClient.CommandOutput([]string{"add", name, "--chmod=+x"}) - return err + if _, err := scopedClient.CommandOutput([]string{"add", name, "--chmod=+x"}); err != nil { + return err + } + + if _, err := scopedClient.CommandOutput([]string{"commit", "-m", "initial commit"}); err != nil { + return ErrInitialCommitFailed + } + + return nil } func (m *Manager) otherBinScaffolding(name string) error { @@ -672,8 +682,15 @@ func (m *Manager) otherBinScaffolding(name string) error { return err } - _, err := scopedClient.CommandOutput([]string{"add", "."}) - return err + if _, err := scopedClient.CommandOutput([]string{"add", "."}); err != nil { + return err + } + + if _, err := scopedClient.CommandOutput([]string{"commit", "-m", "initial commit"}); err != nil { + return ErrInitialCommitFailed + } + + return nil } func (m *Manager) goBinScaffolding(name string) error { @@ -718,8 +735,15 @@ func (m *Manager) goBinScaffolding(name string) error { } scopedClient := m.gitClient.ForRepo(name) - _, err = scopedClient.CommandOutput([]string{"add", "."}) - return err + if _, err := scopedClient.CommandOutput([]string{"add", "."}); err != nil { + return err + } + + if _, err := scopedClient.CommandOutput([]string{"commit", "-m", "initial commit"}); err != nil { + return ErrInitialCommitFailed + } + + return nil } func isSymlink(m os.FileMode) bool { diff --git a/pkg/cmd/extension/manager_test.go b/pkg/cmd/extension/manager_test.go index 960eee45e..4f8b3d75d 100644 --- a/pkg/cmd/extension/manager_test.go +++ b/pkg/cmd/extension/manager_test.go @@ -1036,6 +1036,7 @@ func TestManager_Create(t *testing.T) { gc.On("ForRepo", "gh-test").Return(gcOne).Once() gc.On("CommandOutput", []string{"init", "--quiet", "gh-test"}).Return("", nil).Once() gcOne.On("CommandOutput", []string{"add", "gh-test", "--chmod=+x"}).Return("", nil).Once() + gcOne.On("CommandOutput", []string{"commit", "-m", "initial commit"}).Return("", nil).Once() m := newTestManager(".", nil, gc, ios) @@ -1068,6 +1069,7 @@ func TestManager_Create_go_binary(t *testing.T) { gc.On("ForRepo", "gh-test").Return(gcOne).Once() gc.On("CommandOutput", []string{"init", "--quiet", "gh-test"}).Return("", nil).Once() gcOne.On("CommandOutput", []string{"add", "."}).Return("", nil).Once() + gcOne.On("CommandOutput", []string{"commit", "-m", "initial commit"}).Return("", nil).Once() m := newTestManager(".", &http.Client{Transport: ®}, gc, ios) @@ -1111,6 +1113,7 @@ func TestManager_Create_other_binary(t *testing.T) { gc.On("CommandOutput", []string{"init", "--quiet", "gh-test"}).Return("", nil).Once() gcOne.On("CommandOutput", []string{"add", filepath.Join("script", "build.sh"), "--chmod=+x"}).Return("", nil).Once() gcOne.On("CommandOutput", []string{"add", "."}).Return("", nil).Once() + gcOne.On("CommandOutput", []string{"commit", "-m", "initial commit"}).Return("", nil).Once() m := newTestManager(".", nil, gc, ios) diff --git a/pkg/cmd/gpg-key/delete/delete.go b/pkg/cmd/gpg-key/delete/delete.go index bb5277a38..b8569f72d 100644 --- a/pkg/cmd/gpg-key/delete/delete.go +++ b/pkg/cmd/gpg-key/delete/delete.go @@ -38,7 +38,7 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co opts.KeyID = args[0] if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { @@ -48,7 +48,9 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co }, } - cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt") + cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt") return cmd } diff --git a/pkg/cmd/gpg-key/delete/delete_test.go b/pkg/cmd/gpg-key/delete/delete_test.go index 2835f9cb2..a7f0fda67 100644 --- a/pkg/cmd/gpg-key/delete/delete_test.go +++ b/pkg/cmd/gpg-key/delete/delete_test.go @@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) { { name: "confirm flag tty", tty: true, - input: "ABC123 --confirm", + input: "ABC123 --yes", output: DeleteOptions{KeyID: "ABC123", Confirmed: true}, }, { @@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) { name: "no tty", input: "ABC123", wantErr: true, - wantErrMsg: "--confirm required when not running interactively", + wantErrMsg: "--yes required when not running interactively", }, { name: "confirm flag no tty", - input: "ABC123 --confirm", + input: "ABC123 --yes", output: DeleteOptions{KeyID: "ABC123", Confirmed: true}, }, { diff --git a/pkg/cmd/issue/create/create.go b/pkg/cmd/issue/create/create.go index 081e55b20..9e14760f5 100644 --- a/pkg/cmd/issue/create/create.go +++ b/pkg/cmd/issue/create/create.go @@ -55,6 +55,12 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co cmd := &cobra.Command{ Use: "create", Short: "Create a new issue", + Long: heredoc.Doc(` + Create an issue on GitHub. + + Adding an issue to projects requires authorization with the "project" scope. + To authorize, run "gh auth refresh -s project". + `), Example: heredoc.Doc(` $ gh issue create --title "I found a bug" --body "Nothing works" $ gh issue create --label "bug,help wanted" diff --git a/pkg/cmd/issue/create/create_test.go b/pkg/cmd/issue/create/create_test.go index 9bfaea8b5..de98d52b4 100644 --- a/pkg/cmd/issue/create/create_test.go +++ b/pkg/cmd/issue/create/create_test.go @@ -217,6 +217,24 @@ func Test_createRun(t *testing.T) { ], "pageInfo": { "hasNextPage": false } } } } }`)) + r.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID", "resourcePath": "/OWNER/REPO/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } }`)) + r.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "Triage", "id": "TRIAGEID", "resourcePath": "/orgs/ORG/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } }`)) }, wantsBrowse: "https://github.com/OWNER/REPO/issues/new?body=&projects=OWNER%2FREPO%2F1", wantsStderr: "Opening github.com/OWNER/REPO/issues/new in your browser.\n", @@ -612,6 +630,22 @@ func TestIssueCreate_metadata(t *testing.T) { }] } `)) + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) http.Register( httpmock.GraphQL(`mutation IssueCreate\b`), httpmock.GraphQLMutation(` @@ -625,12 +659,9 @@ func TestIssueCreate_metadata(t *testing.T) { assert.Equal(t, []interface{}{"BUGID", "TODOID"}, inputs["labelIds"]) assert.Equal(t, []interface{}{"ROADMAPID"}, inputs["projectIds"]) assert.Equal(t, "BIGONEID", inputs["milestoneId"]) - if v, ok := inputs["userIds"]; ok { - t.Errorf("did not expect userIds: %v", v) - } - if v, ok := inputs["teamIds"]; ok { - t.Errorf("did not expect teamIds: %v", v) - } + assert.NotContains(t, inputs, "userIds") + assert.NotContains(t, inputs, "teamIds") + assert.NotContains(t, inputs, "projectV2Ids") })) output, err := runCommand(http, true, `-t TITLE -b BODY -a monalisa -l bug -l todo -p roadmap -m 'big one.oh'`, nil) @@ -712,3 +743,81 @@ func TestIssueCreate_AtMeAssignee(t *testing.T) { assert.Equal(t, "https://github.com/OWNER/REPO/issues/12\n", output.String()) } + +func TestIssueCreate_projectsV2(t *testing.T) { + http := &httpmock.Registry{} + defer http.Verify(t) + + http.StubRepoInfoResponse("OWNER", "REPO", "main") + http.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projects": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projects": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TriageV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`mutation IssueCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createIssue": { "issue": { + "id": "Issue#1", + "URL": "https://github.com/OWNER/REPO/issues/12" + } } } } + `, func(inputs map[string]interface{}) { + assert.Equal(t, "TITLE", inputs["title"]) + assert.Equal(t, "BODY", inputs["body"]) + assert.Nil(t, inputs["projectIds"]) + assert.NotContains(t, inputs, "projectV2Ids") + })) + http.Register( + httpmock.GraphQL(`mutation UpdateProjectV2Items\b`), + httpmock.GraphQLQuery(` + { "data": { "add_000": { "item": { + "id": "1" + } } } } + `, func(mutations string, inputs map[string]interface{}) { + variables, err := json.Marshal(inputs) + assert.NoError(t, err) + expectedMutations := "mutation UpdateProjectV2Items($input_000: AddProjectV2ItemByIdInput!) {add_000: addProjectV2ItemById(input: $input_000) { item { id } }}" + expectedVariables := `{"input_000":{"contentId":"Issue#1","projectId":"ROADMAPV2ID"}}` + assert.Equal(t, expectedMutations, mutations) + assert.Equal(t, expectedVariables, string(variables)) + })) + + output, err := runCommand(http, true, `-t TITLE -b BODY -p roadmapv2`, nil) + if err != nil { + t.Errorf("error running command `issue create`: %v", err) + } + + assert.Equal(t, "https://github.com/OWNER/REPO/issues/12\n", output.String()) +} diff --git a/pkg/cmd/issue/delete/delete.go b/pkg/cmd/issue/delete/delete.go index e9297c65a..f0c68de42 100644 --- a/pkg/cmd/issue/delete/delete.go +++ b/pkg/cmd/issue/delete/delete.go @@ -49,11 +49,14 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co if runF != nil { return runF(opts) } + return deleteRun(opts) }, } cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "confirm deletion without prompting") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "confirm deletion without prompting") return cmd } diff --git a/pkg/cmd/issue/edit/edit.go b/pkg/cmd/issue/edit/edit.go index 3f7cb3b84..abea9eaae 100644 --- a/pkg/cmd/issue/edit/edit.go +++ b/pkg/cmd/issue/edit/edit.go @@ -45,6 +45,12 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman cmd := &cobra.Command{ Use: "edit { | }", Short: "Edit an issue", + Long: heredoc.Doc(` + Edit an issue. + + Editing an issue's projects requires authorization with the "project" scope. + To authorize, run "gh auth refresh -s project". + `), Example: heredoc.Doc(` $ gh issue edit 23 --title "I found a bug" --body "Nothing works" $ gh issue edit 23 --add-label "bug,help wanted" --remove-label "core" @@ -145,6 +151,7 @@ func editRun(opts *EditOptions) error { } if opts.Interactive || editable.Projects.Edited { lookupFields = append(lookupFields, "projectCards") + lookupFields = append(lookupFields, "projectItems") } if opts.Interactive || editable.Milestone.Edited { lookupFields = append(lookupFields, "milestone") @@ -159,7 +166,12 @@ func editRun(opts *EditOptions) error { editable.Body.Default = issue.Body editable.Assignees.Default = issue.Assignees.Logins() editable.Labels.Default = issue.Labels.Names() - editable.Projects.Default = issue.ProjectCards.ProjectNames() + editable.Projects.Default = append(issue.ProjectCards.ProjectNames(), issue.ProjectItems.ProjectTitles()...) + projectItems := map[string]string{} + for _, n := range issue.ProjectItems.Nodes { + projectItems[n.Project.ID] = n.ID + } + editable.Projects.ProjectItems = projectItems if issue.Milestone != nil { editable.Milestone.Default = issue.Milestone.Title } diff --git a/pkg/cmd/issue/edit/edit_test.go b/pkg/cmd/issue/edit/edit_test.go index a43b3ae19..891fe2414 100644 --- a/pkg/cmd/issue/edit/edit_test.go +++ b/pkg/cmd/issue/edit/edit_test.go @@ -165,9 +165,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ SelectorArg: "23", Editable: prShared.Editable{ - Projects: prShared.EditableSlice{ - Add: []string{"Cleanup", "Roadmap"}, - Edited: true, + Projects: prShared.EditableProjects{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"Cleanup", "Roadmap"}, + Edited: true, + }, }, }, }, @@ -179,9 +181,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ SelectorArg: "23", Editable: prShared.Editable{ - Projects: prShared.EditableSlice{ - Remove: []string{"Cleanup", "Roadmap"}, - Edited: true, + Projects: prShared.EditableProjects{ + EditableSlice: prShared.EditableSlice{ + Remove: []string{"Cleanup", "Roadmap"}, + Edited: true, + }, }, }, }, @@ -278,10 +282,12 @@ func Test_editRun(t *testing.T) { Remove: []string{"docs"}, Edited: true, }, - Projects: prShared.EditableSlice{ - Add: []string{"Cleanup", "Roadmap"}, - Remove: []string{"Features"}, - Edited: true, + Projects: prShared.EditableProjects{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"Cleanup", "CleanupV2"}, + Remove: []string{"Roadmap", "RoadmapV2"}, + Edited: true, + }, }, Milestone: prShared.EditableString{ Value: "GA", @@ -297,9 +303,11 @@ func Test_editRun(t *testing.T) { }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { mockIssueGet(t, reg) + mockIssueProjectItemsGet(t, reg) mockRepoMetadata(t, reg) mockIssueUpdate(t, reg) mockIssueUpdateLabels(t, reg) + mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/issue/123\n", }, @@ -322,7 +330,9 @@ func Test_editRun(t *testing.T) { eo.Body.Value = "new body" eo.Assignees.Value = []string{"monalisa", "hubot"} eo.Labels.Value = []string{"feature", "TODO", "bug"} - eo.Projects.Value = []string{"Cleanup", "Roadmap"} + eo.Labels.Add = []string{"feature", "TODO", "bug"} + eo.Labels.Remove = []string{"docs"} + eo.Projects.Value = []string{"Cleanup", "CleanupV2"} eo.Milestone.Value = "GA" return nil }, @@ -331,8 +341,11 @@ func Test_editRun(t *testing.T) { }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { mockIssueGet(t, reg) + mockIssueProjectItemsGet(t, reg) mockRepoMetadata(t, reg) mockIssueUpdate(t, reg) + mockIssueUpdateLabels(t, reg) + mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/issue/123\n", }, @@ -369,7 +382,31 @@ func mockIssueGet(_ *testing.T, reg *httpmock.Registry) { httpmock.StringResponse(` { "data": { "repository": { "hasIssuesEnabled": true, "issue": { "number": 123, - "url": "https://github.com/OWNER/REPO/issue/123" + "url": "https://github.com/OWNER/REPO/issue/123", + "labels": { + "nodes": [ + { "id": "DOCSID", "name": "docs" } + ], "totalCount": 1 + }, + "projectCards": { + "nodes": [ + { "project": { "name": "Roadmap" } } + ], "totalCount": 1 + } + } } } }`), + ) +} + +func mockIssueProjectItemsGet(_ *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query IssueProjectItems\b`), + httpmock.StringResponse(` + { "data": { "repository": { "issue": { + "projectItems": { + "nodes": [ + { "id": "ITEMID", "project": { "title": "RoadmapV2" } } + ] + } } } } }`), ) } @@ -431,6 +468,27 @@ func mockRepoMetadata(_ *testing.T, reg *httpmock.Registry) { "pageInfo": { "hasNextPage": false } } } } } `)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) } func mockIssueUpdate(t *testing.T, reg *httpmock.Registry) { @@ -456,3 +514,12 @@ func mockIssueUpdateLabels(t *testing.T, reg *httpmock.Registry) { func(inputs map[string]interface{}) {}), ) } + +func mockProjectV2ItemUpdate(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation UpdateProjectV2Items\b`), + httpmock.GraphQLMutation(` + { "data": { "add_000": { "item": { "id": "1" } }, "delete_001": { "item": { "id": "2" } } } }`, + func(inputs map[string]interface{}) {}), + ) +} diff --git a/pkg/cmd/issue/shared/lookup.go b/pkg/cmd/issue/shared/lookup.go index 0b766292b..3c1d2fdf4 100644 --- a/pkg/cmd/issue/shared/lookup.go +++ b/pkg/cmd/issue/shared/lookup.go @@ -97,6 +97,13 @@ func findIssueOrPR(httpClient *http.Client, repo ghrepo.Interface, number int, f fieldSet.Remove("stateReason") } } + + var getProjectItems bool + if fieldSet.Contains("projectItems") { + getProjectItems = true + fieldSet.Remove("projectItems") + } + fields = fieldSet.ToSlice() type response struct { @@ -151,5 +158,13 @@ func findIssueOrPR(httpClient *http.Client, repo ghrepo.Interface, number int, f return nil, errors.New("issue was not found but GraphQL reported no error") } + if getProjectItems { + apiClient := api.NewClientFromHTTP(httpClient) + err := api.ProjectsV2ItemsForIssue(apiClient, repo, resp.Repository.Issue) + if err != nil && !api.ProjectsV2IgnorableError(err) { + return nil, err + } + } + return resp.Repository.Issue, nil } diff --git a/pkg/cmd/label/delete.go b/pkg/cmd/label/delete.go index 8788b7215..c9d8f4cae 100644 --- a/pkg/cmd/label/delete.go +++ b/pkg/cmd/label/delete.go @@ -42,7 +42,7 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co opts.Name = args[0] if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { @@ -53,6 +53,8 @@ func newCmdDelete(f *cmdutil.Factory, runF func(*deleteOptions) error) *cobra.Co } cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Confirm deletion without prompting") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "Confirm deletion without prompting") return cmd } diff --git a/pkg/cmd/label/delete_test.go b/pkg/cmd/label/delete_test.go index 3bee0d987..f43b8d5bd 100644 --- a/pkg/cmd/label/delete_test.go +++ b/pkg/cmd/label/delete_test.go @@ -37,14 +37,14 @@ func TestNewCmdDelete(t *testing.T) { }, { name: "confirm argument", - input: "test --confirm", + input: "test --yes", output: deleteOptions{Name: "test", Confirmed: true}, }, { name: "confirm no tty", input: "test", wantErr: true, - wantErrMsg: "--confirm required when not running interactively", + wantErrMsg: "--yes required when not running interactively", }, } diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 3248e485d..ca10bafea 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -110,6 +110,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co By default, users with write access to the base repository can push new commits to the head branch of the pull request. Disable this with %[1]s--no-maintainer-edit%[1]s. + + Adding a pull request to projects requires authorization with the "project" scope. + To authorize, run "gh auth refresh -s project". `, "`"), Example: heredoc.Doc(` $ gh pr create --title "The bug is fixed" --body "Everything works again" @@ -179,6 +182,14 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co fl.Bool("no-maintainer-edit", false, "Disable maintainer's ability to modify pull request") fl.StringVar(&opts.RecoverFile, "recover", "", "Recover input from a failed run of create") + _ = cmd.RegisterFlagCompletionFunc("reviewer", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + results, err := requestableReviewersForCompletion(opts) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + return results, cobra.ShellCompDirectiveNoFileComp + }) + return cmd } @@ -469,16 +480,10 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { } client := api.NewClientFromHTTP(httpClient) - // TODO: consider obtaining remotes from GitClient instead - remotes, err := opts.Remotes() + remotes, err := getRemotes(opts) if err != nil { - // When a repo override value is given, ignore errors when fetching git remotes - // to support using this command outside of git repos. - if opts.RepoOverride == "" { - return nil, err - } + return nil, err } - repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) if err != nil { return nil, err @@ -625,6 +630,19 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { } +func getRemotes(opts *CreateOptions) (ghContext.Remotes, error) { + // TODO: consider obtaining remotes from GitClient instead + remotes, err := opts.Remotes() + if err != nil { + // When a repo override value is given, ignore errors when fetching git remotes + // to support using this command outside of git repos. + if opts.RepoOverride == "" { + return nil, err + } + } + return remotes, nil +} + func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataState) error { client := ctx.Client @@ -680,7 +698,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { // one by forking the base repository if headRepo == nil && ctx.IsPushEnabled { opts.IO.StartProgressIndicator() - headRepo, err = api.ForkRepo(client, ctx.BaseRepo, "", "") + headRepo, err = api.ForkRepo(client, ctx.BaseRepo, "", "", false) opts.IO.StopProgressIndicator() if err != nil { return fmt.Errorf("error forking repo: %w", err) @@ -779,4 +797,26 @@ func humanize(s string) string { return strings.Map(h, s) } +func requestableReviewersForCompletion(opts *CreateOptions) ([]string, error) { + httpClient, err := opts.HttpClient() + if err != nil { + return nil, err + } + + remotes, err := getRemotes(opts) + if err != nil { + return nil, err + } + repoContext, err := ghContext.ResolveRemotesToRepos(remotes, api.NewClientFromHTTP(httpClient), opts.RepoOverride) + if err != nil { + return nil, err + } + baseRepo, err := repoContext.BaseRepo(opts.IO) + if err != nil { + return nil, err + } + + return shared.RequestableReviewersForCompletion(httpClient, baseRepo) +} + var gitPushRegexp = regexp.MustCompile("^remote: (Create a pull request.*by visiting|[[:space:]]*https://.*/pull/new/).*\n?$") diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 26e2a71d0..c9d1d70e1 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -280,6 +280,94 @@ func Test_createRun(t *testing.T) { expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n", }, + { + name: "project v2", + tty: true, + setup: func(opts *CreateOptions, t *testing.T) func() { + opts.TitleProvided = true + opts.BodyProvided = true + opts.Title = "my title" + opts.Body = "my body" + opts.Projects = []string{"RoadmapV2"} + return func() {} + }, + httpStubs: func(reg *httpmock.Registry, t *testing.T) { + reg.StubRepoResponse("OWNER", "REPO") + reg.Register( + httpmock.GraphQL(`query UserCurrent\b`), + httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(`{ "data": { "repository": { "projects": { "nodes": [], "pageInfo": { "hasNextPage": false } } } } }`)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(`{ "data": { "organization": { "projects": { "nodes": [], "pageInfo": { "hasNextPage": false } } } } }`)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "id": "PullRequest#1", + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } } + `, func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"].(string)) + assert.Equal(t, "my title", input["title"].(string)) + assert.Equal(t, "my body", input["body"].(string)) + assert.Equal(t, "master", input["baseRefName"].(string)) + assert.Equal(t, "feature", input["headRefName"].(string)) + assert.Equal(t, false, input["draft"].(bool)) + })) + reg.Register( + httpmock.GraphQL(`mutation UpdateProjectV2Items\b`), + httpmock.GraphQLQuery(` + { "data": { "add_000": { "item": { + "id": "1" + } } } } + `, func(mutations string, inputs map[string]interface{}) { + variables, err := json.Marshal(inputs) + assert.NoError(t, err) + expectedMutations := "mutation UpdateProjectV2Items($input_000: AddProjectV2ItemByIdInput!) {add_000: addProjectV2ItemById(input: $input_000) { item { id } }}" + expectedVariables := `{"input_000":{"contentId":"PullRequest#1","projectId":"ROADMAPV2ID"}}` + assert.Equal(t, expectedMutations, mutations) + assert.Equal(t, expectedVariables, string(variables)) + })) + }, + cmdStubs: func(cs *run.CommandStubber) { + cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") + cs.Register(`git push --set-upstream origin HEAD:feature`, 0, "") + }, + promptStubs: func(pm *prompter.PrompterMock) { + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + if p == "Where should we push the 'feature' branch?" { + return 0, nil + } else { + return -1, prompter.NoSuchPromptErr(p) + } + } + }, + expectedOut: "https://github.com/OWNER/REPO/pull/12\n", + expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n", + }, { name: "no maintainer modify", tty: true, @@ -363,7 +451,7 @@ func Test_createRun(t *testing.T) { cmdStubs: func(cs *run.CommandStubber) { cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "") cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") - cs.Register(`git remote add -f fork https://github.com/monalisa/REPO.git`, 0, "") + cs.Register(`git remote add fork https://github.com/monalisa/REPO.git`, 0, "") cs.Register(`git push --set-upstream fork HEAD:feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -575,6 +663,17 @@ func Test_createRun(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) reg.Register( httpmock.GraphQL(`query OrganizationProjectList\b`), httpmock.StringResponse(` @@ -583,6 +682,14 @@ func Test_createRun(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) reg.Register( httpmock.GraphQL(`mutation PullRequestCreate\b`), httpmock.GraphQLMutation(` @@ -696,6 +803,16 @@ func Test_createRun(t *testing.T) { ], "pageInfo": { "hasNextPage": false } } } } } + `)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID", "resourcePath": "/OWNER/REPO/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } } `)) reg.Register( httpmock.GraphQL(`query OrganizationProjectList\b`), @@ -706,6 +823,16 @@ func Test_createRun(t *testing.T) { ], "pageInfo": { "hasNextPage": false } } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID", "resourcePath": "/orgs/ORG/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } } `)) }, cmdStubs: func(cs *run.CommandStubber) { diff --git a/pkg/cmd/pr/edit/edit.go b/pkg/cmd/pr/edit/edit.go index bad2cfa8f..216d0d6d2 100644 --- a/pkg/cmd/pr/edit/edit.go +++ b/pkg/cmd/pr/edit/edit.go @@ -50,6 +50,9 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman Without an argument, the pull request that belongs to the current branch is selected. + + Editing a pull request's projects requires authorization with the "project" scope. + To authorize, run "gh auth refresh -s project". `), Example: heredoc.Doc(` $ gh pr edit 23 --title "I found a bug" --body "Nothing works" @@ -145,13 +148,31 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman cmd.Flags().StringSliceVar(&opts.Editable.Projects.Remove, "remove-project", nil, "Remove the pull request from projects by `name`") cmd.Flags().StringVarP(&opts.Editable.Milestone.Value, "milestone", "m", "", "Edit the milestone the pull request belongs to by `name`") + for _, flagName := range []string{"add-reviewer", "remove-reviewer"} { + _ = cmd.RegisterFlagCompletionFunc(flagName, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + baseRepo, err := f.BaseRepo() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + httpClient, err := f.HttpClient() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + results, err := shared.RequestableReviewersForCompletion(httpClient, baseRepo) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + return results, cobra.ShellCompDirectiveNoFileComp + }) + } + return cmd } func editRun(opts *EditOptions) error { findOptions := shared.FindOptions{ Selector: opts.SelectorArg, - Fields: []string{"id", "url", "title", "body", "baseRefName", "reviewRequests", "assignees", "labels", "projectCards", "milestone"}, + Fields: []string{"id", "url", "title", "body", "baseRefName", "reviewRequests", "assignees", "labels", "projectCards", "projectItems", "milestone"}, } pr, repo, err := opts.Finder.Find(findOptions) if err != nil { @@ -166,7 +187,12 @@ func editRun(opts *EditOptions) error { editable.Reviewers.Default = pr.ReviewRequests.Logins() editable.Assignees.Default = pr.Assignees.Logins() editable.Labels.Default = pr.Labels.Names() - editable.Projects.Default = pr.ProjectCards.ProjectNames() + editable.Projects.Default = append(pr.ProjectCards.ProjectNames(), pr.ProjectItems.ProjectTitles()...) + projectItems := map[string]string{} + for _, n := range pr.ProjectItems.Nodes { + projectItems[n.Project.ID] = n.ID + } + editable.Projects.ProjectItems = projectItems if pr.Milestone != nil { editable.Milestone.Default = pr.Milestone.Title } diff --git a/pkg/cmd/pr/edit/edit_test.go b/pkg/cmd/pr/edit/edit_test.go index fb7c00932..f81e08cd1 100644 --- a/pkg/cmd/pr/edit/edit_test.go +++ b/pkg/cmd/pr/edit/edit_test.go @@ -216,9 +216,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ SelectorArg: "23", Editable: shared.Editable{ - Projects: shared.EditableSlice{ - Add: []string{"Cleanup", "Roadmap"}, - Edited: true, + Projects: shared.EditableProjects{ + EditableSlice: shared.EditableSlice{ + Add: []string{"Cleanup", "Roadmap"}, + Edited: true, + }, }, }, }, @@ -230,9 +232,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ SelectorArg: "23", Editable: shared.Editable{ - Projects: shared.EditableSlice{ - Remove: []string{"Cleanup", "Roadmap"}, - Edited: true, + Projects: shared.EditableProjects{ + EditableSlice: shared.EditableSlice{ + Remove: []string{"Cleanup", "Roadmap"}, + Edited: true, + }, }, }, }, @@ -341,10 +345,12 @@ func Test_editRun(t *testing.T) { Remove: []string{"docs"}, Edited: true, }, - Projects: shared.EditableSlice{ - Add: []string{"Cleanup", "Roadmap"}, - Remove: []string{"Features"}, - Edited: true, + Projects: shared.EditableProjects{ + EditableSlice: shared.EditableSlice{ + Add: []string{"Cleanup", "CleanupV2"}, + Remove: []string{"Roadmap", "RoadmapV2"}, + Edited: true, + }, }, Milestone: shared.EditableString{ Value: "GA", @@ -358,6 +364,7 @@ func Test_editRun(t *testing.T) { mockPullRequestUpdate(t, reg) mockPullRequestReviewersUpdate(t, reg) mockPullRequestUpdateLabels(t, reg) + mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, @@ -392,10 +399,12 @@ func Test_editRun(t *testing.T) { Remove: []string{"docs"}, Edited: true, }, - Projects: shared.EditableSlice{ - Value: []string{"Cleanup", "Roadmap"}, - Remove: []string{"Features"}, - Edited: true, + Projects: shared.EditableProjects{ + EditableSlice: shared.EditableSlice{ + Add: []string{"Cleanup", "CleanupV2"}, + Remove: []string{"Roadmap", "RoadmapV2"}, + Edited: true, + }, }, Milestone: shared.EditableString{ Value: "GA", @@ -408,6 +417,7 @@ func Test_editRun(t *testing.T) { mockRepoMetadata(t, reg, true) mockPullRequestUpdate(t, reg) mockPullRequestUpdateLabels(t, reg) + mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, @@ -427,6 +437,8 @@ func Test_editRun(t *testing.T) { mockRepoMetadata(t, reg, false) mockPullRequestUpdate(t, reg) mockPullRequestReviewersUpdate(t, reg) + mockPullRequestUpdateLabels(t, reg) + mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, @@ -445,6 +457,8 @@ func Test_editRun(t *testing.T) { httpStubs: func(t *testing.T, reg *httpmock.Registry) { mockRepoMetadata(t, reg, true) mockPullRequestUpdate(t, reg) + mockPullRequestUpdateLabels(t, reg) + mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, @@ -465,6 +479,7 @@ func Test_editRun(t *testing.T) { tt.input.HttpClient = httpClient t.Run(tt.name, func(t *testing.T) { + fmt.Println(tt.name) err := editRun(tt.input) assert.NoError(t, err) assert.Equal(t, tt.stdout, stdout.String()) @@ -530,6 +545,27 @@ func mockRepoMetadata(_ *testing.T, reg *httpmock.Registry, skipReviewers bool) "pageInfo": { "hasNextPage": false } } } } } `)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) if !skipReviewers { reg.Register( httpmock.GraphQL(`query OrganizationTeamList\b`), @@ -577,6 +613,15 @@ func mockPullRequestUpdateLabels(t *testing.T, reg *httpmock.Registry) { ) } +func mockProjectV2ItemUpdate(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation UpdateProjectV2Items\b`), + httpmock.GraphQLMutation(` + { "data": { "add_000": { "item": { "id": "1" } }, "delete_001": { "item": { "id": "2" } } } }`, + func(inputs map[string]interface{}) {}), + ) +} + type testFetcher struct{} type testSurveyor struct { skipReviewers bool @@ -608,7 +653,9 @@ func (s testSurveyor) EditFields(e *shared.Editable, _ string) error { } e.Assignees.Value = []string{"monalisa", "hubot"} e.Labels.Value = []string{"feature", "TODO", "bug"} - e.Projects.Value = []string{"Cleanup", "Roadmap"} + e.Labels.Add = []string{"feature", "TODO", "bug"} + e.Labels.Remove = []string{"docs"} + e.Projects.Value = []string{"Cleanup", "CleanupV2"} e.Milestone.Value = "GA" return nil } diff --git a/pkg/cmd/pr/shared/completion.go b/pkg/cmd/pr/shared/completion.go new file mode 100644 index 000000000..e07abc5a7 --- /dev/null +++ b/pkg/cmd/pr/shared/completion.go @@ -0,0 +1,39 @@ +package shared + +import ( + "fmt" + "net/http" + "sort" + "strings" + "time" + + "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/ghrepo" +) + +func RequestableReviewersForCompletion(httpClient *http.Client, repo ghrepo.Interface) ([]string, error) { + client := api.NewClientFromHTTP(api.NewCachedHTTPClient(httpClient, time.Minute*2)) + + metadata, err := api.RepoMetadata(client, repo, api.RepoMetadataInput{Reviewers: true}) + if err != nil { + return nil, err + } + + results := []string{} + for _, user := range metadata.AssignableUsers { + if strings.EqualFold(user.Login, metadata.CurrentLogin) { + continue + } + if user.Name != "" { + results = append(results, fmt.Sprintf("%s\t%s", user.Login, user.Name)) + } else { + results = append(results, user.Login) + } + } + for _, team := range metadata.Teams { + results = append(results, fmt.Sprintf("%s/%s", repo.RepoOwner(), team.Slug)) + } + + sort.Strings(results) + return results, nil +} diff --git a/pkg/cmd/pr/shared/editable.go b/pkg/cmd/pr/shared/editable.go index 57b042937..7564c24cd 100644 --- a/pkg/cmd/pr/shared/editable.go +++ b/pkg/cmd/pr/shared/editable.go @@ -18,7 +18,7 @@ type Editable struct { Reviewers EditableSlice Assignees EditableSlice Labels EditableSlice - Projects EditableSlice + Projects EditableProjects Milestone EditableString Metadata api.RepoMetadataResult } @@ -40,6 +40,13 @@ type EditableSlice struct { Allowed bool } +// ProjectsV2 mutations require a mapping of an item ID to a project ID. +// Keep that map along with standard EditableSlice data. +type EditableProjects struct { + EditableSlice + ProjectItems map[string]string +} + func (e Editable) Dirty() bool { return e.Title.Edited || e.Body.Edited || @@ -120,6 +127,7 @@ func (e Editable) AssigneeIds(client *api.Client, repo ghrepo.Interface) (*[]str return &a, err } +// ProjectIds returns a slice containing IDs of projects v1 that the issue or a PR has to be linked to. func (e Editable) ProjectIds() (*[]string, error) { if !e.Projects.Edited { return nil, nil @@ -131,10 +139,56 @@ func (e Editable) ProjectIds() (*[]string, error) { s.RemoveValues(e.Projects.Remove) e.Projects.Value = s.ToSlice() } - p, err := e.Metadata.ProjectsToIDs(e.Projects.Value) + p, _, err := e.Metadata.ProjectsToIDs(e.Projects.Value) return &p, err } +// ProjectV2Ids returns a pair of slices. +// The first is the projects the item should be added to. +// The second is the projects the items should be removed from. +func (e Editable) ProjectV2Ids() (*[]string, *[]string, error) { + if !e.Projects.Edited { + return nil, nil, nil + } + + // titles of projects to add + addTitles := set.NewStringSet() + // titles of projects to remove + removeTitles := set.NewStringSet() + + if len(e.Projects.Add) != 0 || len(e.Projects.Remove) != 0 { + // Projects were selected using flags. + addTitles.AddValues(e.Projects.Add) + removeTitles.AddValues(e.Projects.Remove) + } else { + // Projects were selected interactively. + addTitles.AddValues(e.Projects.Value) + addTitles.RemoveValues(e.Projects.Default) + removeTitles.AddValues(e.Projects.Default) + removeTitles.RemoveValues(e.Projects.Value) + } + + var addIds []string + var removeIds []string + var err error + + if addTitles.Len() > 0 { + _, addIds, err = e.Metadata.ProjectsToIDs(addTitles.ToSlice()) + if err != nil { + return nil, nil, err + } + } + + if removeTitles.Len() > 0 { + _, removeIds, err = e.Metadata.ProjectsToIDs(removeTitles.ToSlice()) + if err != nil { + return nil, nil, err + } + } + + return &addIds, &removeIds, nil +} + func (e Editable) MilestoneId() (*string, error) { if !e.Milestone.Edited { return nil, nil @@ -285,8 +339,11 @@ func FetchOptions(client *api.Client, repo ghrepo.Interface, editable *Editable) labels = append(labels, l.Name) } var projects []string - for _, l := range metadata.Projects { - projects = append(projects, l.Name) + for _, p := range metadata.Projects { + projects = append(projects, p.Name) + } + for _, p := range metadata.ProjectsV2 { + projects = append(projects, p.Title) } milestones := []string{noMilestone} for _, m := range metadata.Milestones { diff --git a/pkg/cmd/pr/shared/editable_http.go b/pkg/cmd/pr/shared/editable_http.go index 07bcb9a7f..3353cd5f5 100644 --- a/pkg/cmd/pr/shared/editable_http.go +++ b/pkg/cmd/pr/shared/editable_http.go @@ -35,6 +35,29 @@ func UpdateIssue(httpClient *http.Client, repo ghrepo.Interface, id string, isPR } } + // updateIssue mutation does not support ProjectsV2 so do them in a seperate request. + if options.Projects.Edited { + wg.Go(func() error { + apiClient := api.NewClientFromHTTP(httpClient) + addIds, removeIds, err := options.ProjectV2Ids() + if err != nil { + return err + } + if addIds == nil && removeIds == nil { + return nil + } + toAdd := make(map[string]string, len(*addIds)) + toRemove := make(map[string]string, len(*removeIds)) + for _, p := range *addIds { + toAdd[p] = id + } + for _, p := range *removeIds { + toRemove[p] = options.Projects.ProjectItems[p] + } + return api.UpdateProjectV2Items(apiClient, repo, toAdd, toRemove) + }) + } + if dirtyExcludingLabels(options) { wg.Go(func() error { return replaceIssueFields(httpClient, repo, id, isPR, options) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 368134871..f1886b3fe 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -139,7 +139,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err fields := set.NewStringSet() fields.AddValues(opts.Fields) numberFieldOnly := fields.Len() == 1 && fields.Contains("number") - fields.Add("id") // for additional preload queries below + fields.AddValues([]string{"id", "number"}) // for additional preload queries below if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") { cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) @@ -154,6 +154,12 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err } } + var getProjectItems bool + if fields.Contains("projectItems") { + getProjectItems = true + fields.Remove("projectItems") + } + var pr *api.PullRequest if f.prNumber > 0 { if numberFieldOnly { @@ -184,6 +190,16 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return preloadPrChecks(httpClient, f.repo, pr) }) } + if getProjectItems { + g.Go(func() error { + apiClient := api.NewClientFromHTTP(httpClient) + err := api.ProjectsV2ItemsForPullRequest(apiClient, f.repo, pr) + if err != nil && !api.ProjectsV2IgnorableError(err) { + return err + } + return nil + }) + } return pr, f.repo, g.Wait() } diff --git a/pkg/cmd/pr/shared/params.go b/pkg/cmd/pr/shared/params.go index 5b13cf681..1b82c33f5 100644 --- a/pkg/cmd/pr/shared/params.go +++ b/pkg/cmd/pr/shared/params.go @@ -109,11 +109,12 @@ func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, par } params["labelIds"] = labelIDs - projectIDs, err := tb.MetadataResult.ProjectsToIDs(tb.Projects) + projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsToIDs(tb.Projects) if err != nil { return fmt.Errorf("could not add to project: %w", err) } params["projectIds"] = projectIDs + params["projectV2Ids"] = projectV2IDs if len(tb.Milestones) > 0 { milestoneID, err := tb.MetadataResult.MilestoneToID(tb.Milestones[0]) diff --git a/pkg/cmd/pr/shared/survey.go b/pkg/cmd/pr/shared/survey.go index d764e2680..8582373b0 100644 --- a/pkg/cmd/pr/shared/survey.go +++ b/pkg/cmd/pr/shared/survey.go @@ -203,8 +203,11 @@ func MetadataSurvey(io *iostreams.IOStreams, baseRepo ghrepo.Interface, fetcher labels = append(labels, l.Name) } var projects []string - for _, l := range metadataResult.Projects { - projects = append(projects, l.Name) + for _, p := range metadataResult.Projects { + projects = append(projects, p.Name) + } + for _, p := range metadataResult.ProjectsV2 { + projects = append(projects, p.Title) } milestones := []string{noMilestone} for _, m := range metadataResult.Milestones { diff --git a/pkg/cmd/release/list/list.go b/pkg/cmd/release/list/list.go index 13e0cf8f3..f2258a933 100644 --- a/pkg/cmd/release/list/list.go +++ b/pkg/cmd/release/list/list.go @@ -3,6 +3,7 @@ package list import ( "fmt" "net/http" + "time" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/tableprinter" @@ -107,7 +108,7 @@ func listRun(opts *ListOptions) error { if rel.PublishedAt.IsZero() { pubDate = rel.CreatedAt } - table.AddTimeField(pubDate, iofmt.Gray) + table.AddTimeField(time.Now(), pubDate, iofmt.Gray) table.EndRow() } err = table.Render() diff --git a/pkg/cmd/repo/archive/archive.go b/pkg/cmd/repo/archive/archive.go index 9e18ea1c1..50b1e7908 100644 --- a/pkg/cmd/repo/archive/archive.go +++ b/pkg/cmd/repo/archive/archive.go @@ -47,16 +47,20 @@ With no argument, archives the current repository.`), } if !opts.Confirmed && !opts.IO.CanPrompt() { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } + if runF != nil { return runF(opts) } + return archiveRun(opts) }, } - cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt") + cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt") return cmd } diff --git a/pkg/cmd/repo/archive/archive_test.go b/pkg/cmd/repo/archive/archive_test.go index 6d681e784..02aab3383 100644 --- a/pkg/cmd/repo/archive/archive_test.go +++ b/pkg/cmd/repo/archive/archive_test.go @@ -26,7 +26,7 @@ func TestNewCmdArchive(t *testing.T) { { name: "no arguments no tty", input: "", - errMsg: "--confirm required when not running interactively", + errMsg: "--yes required when not running interactively", wantErr: true, }, { diff --git a/pkg/cmd/repo/clone/clone.go b/pkg/cmd/repo/clone/clone.go index cb46d4572..4a5ce3e82 100644 --- a/pkg/cmd/repo/clone/clone.go +++ b/pkg/cmd/repo/clone/clone.go @@ -179,6 +179,10 @@ func cloneRun(opts *CloneOptions) error { if err != nil { return err } + + if err := gitClient.Fetch(ctx, upstreamName, "", git.WithRepoDir(cloneDir)); err != nil { + return err + } } return nil } diff --git a/pkg/cmd/repo/clone/clone_test.go b/pkg/cmd/repo/clone/clone_test.go index 4e7750522..98fc3736c 100644 --- a/pkg/cmd/repo/clone/clone_test.go +++ b/pkg/cmd/repo/clone/clone_test.go @@ -245,7 +245,8 @@ func Test_RepoClone_hasParent(t *testing.T) { defer cmdTeardown(t) cs.Register(`git clone https://github.com/OWNER/REPO.git`, 0, "") - cs.Register(`git -C REPO remote add -t trunk -f upstream https://github.com/hubot/ORIG.git`, 0, "") + cs.Register(`git -C REPO remote add -t trunk upstream https://github.com/hubot/ORIG.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") _, err := runCloneCommand(httpClient, "OWNER/REPO") if err != nil { @@ -281,7 +282,8 @@ func Test_RepoClone_hasParent_upstreamRemoteName(t *testing.T) { defer cmdTeardown(t) cs.Register(`git clone https://github.com/OWNER/REPO.git`, 0, "") - cs.Register(`git -C REPO remote add -t trunk -f test https://github.com/hubot/ORIG.git`, 0, "") + cs.Register(`git -C REPO remote add -t trunk test https://github.com/hubot/ORIG.git`, 0, "") + cs.Register(`git -C REPO fetch test`, 0, "") _, err := runCloneCommand(httpClient, "OWNER/REPO --upstream-remote-name test") if err != nil { diff --git a/pkg/cmd/repo/delete/delete.go b/pkg/cmd/repo/delete/delete.go index 4f1a5b882..9de2caab1 100644 --- a/pkg/cmd/repo/delete/delete.go +++ b/pkg/cmd/repo/delete/delete.go @@ -48,18 +48,22 @@ To authorize, run "gh auth refresh -s delete_repo"`, if len(args) > 0 { opts.RepoArg = args[0] } + if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { return runF(opts) } + return deleteRun(opts) }, } cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "confirm deletion without prompting") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "confirm deletion without prompting") return cmd } diff --git a/pkg/cmd/repo/delete/delete_test.go b/pkg/cmd/repo/delete/delete_test.go index a08cf3473..74f8a5699 100644 --- a/pkg/cmd/repo/delete/delete_test.go +++ b/pkg/cmd/repo/delete/delete_test.go @@ -29,12 +29,18 @@ func TestNewCmdDelete(t *testing.T) { input: "OWNER/REPO --confirm", output: DeleteOptions{RepoArg: "OWNER/REPO", Confirmed: true}, }, + { + name: "yes flag", + tty: true, + input: "OWNER/REPO --yes", + output: DeleteOptions{RepoArg: "OWNER/REPO", Confirmed: true}, + }, { name: "no confirmation notty", input: "OWNER/REPO", output: DeleteOptions{RepoArg: "OWNER/REPO"}, wantErr: true, - errMsg: "--confirm required when not running interactively", + errMsg: "--yes required when not running interactively", }, { name: "base repo resolution", diff --git a/pkg/cmd/repo/edit/edit.go b/pkg/cmd/repo/edit/edit.go index b5e4db996..bbbd94dd5 100644 --- a/pkg/cmd/repo/edit/edit.go +++ b/pkg/cmd/repo/edit/edit.go @@ -16,6 +16,7 @@ import ( fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/prompt" @@ -36,6 +37,7 @@ const ( optionIssues = "Issues" optionMergeOptions = "Merge Options" optionProjects = "Projects" + optionDiscussions = "Discussions" optionTemplateRepo = "Template Repository" optionTopics = "Topics" optionVisibility = "Visibility" @@ -51,6 +53,7 @@ type EditOptions struct { RemoveTopics []string InteractiveMode bool Detector fd.Detector + Prompter prompter.Prompter // Cache of current repo topics to avoid retrieving them // in multiple flows. topicsCache []string @@ -66,6 +69,7 @@ type EditRepositoryInput struct { EnableIssues *bool `json:"has_issues,omitempty"` EnableMergeCommit *bool `json:"allow_merge_commit,omitempty"` EnableProjects *bool `json:"has_projects,omitempty"` + EnableDiscussions *bool `json:"has_discussions,omitempty"` EnableRebaseMerge *bool `json:"allow_rebase_merge,omitempty"` EnableSquashMerge *bool `json:"allow_squash_merge,omitempty"` EnableWiki *bool `json:"has_wiki,omitempty"` @@ -76,7 +80,8 @@ type EditRepositoryInput struct { func NewCmdEdit(f *cmdutil.Factory, runF func(options *EditOptions) error) *cobra.Command { opts := &EditOptions{ - IO: f.IOStreams, + IO: f.IOStreams, + Prompter: f.Prompter, } cmd := &cobra.Command{ @@ -93,6 +98,8 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(options *EditOptions) error) *cobr Edit repository settings. To toggle a setting off, use the %[1]s--flag=false%[1]s syntax. + + Note that changing repository visibility to private will cause loss of stars and watchers. `, "`"), Args: cobra.MaximumNArgs(1), Example: heredoc.Doc(` @@ -146,6 +153,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(options *EditOptions) error) *cobr cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableIssues, "enable-issues", "", "Enable issues in the repository") cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableProjects, "enable-projects", "", "Enable projects in the repository") cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableWiki, "enable-wiki", "", "Enable wiki in the repository") + cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableDiscussions, "enable-discussions", "", "Enable discussions in the repository") cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableMergeCommit, "enable-merge-commit", "", "Enable merging pull requests via merge commit") cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableSquashMerge, "enable-squash-merge", "", "Enable merging pull requests via squashed commit") cmdutil.NilBoolFlag(cmd, &opts.Edits.EnableRebaseMerge, "enable-rebase-merge", "", "Enable merging pull requests via rebase") @@ -181,13 +189,17 @@ func editRun(ctx context.Context, opts *EditOptions) error { "hasIssuesEnabled", "hasProjectsEnabled", "hasWikiEnabled", + // TODO: GitHub Enterprise Server does not support has_discussions yet + // "hasDiscussionsEnabled", "homepageUrl", "isInOrganization", "isTemplate", "mergeCommitAllowed", "rebaseMergeAllowed", "repositoryTopics", + "stargazerCount", "squashMergeAllowed", + "watchers", } if repoFeatures.VisibilityField { fieldsToRetrieve = append(fieldsToRetrieve, "visibility") @@ -275,6 +287,8 @@ func interactiveChoice(r *api.Repository) ([]string, error) { optionIssues, optionMergeOptions, optionProjects, + // TODO: GitHub Enterprise Server does not support has_discussions yet + // optionDiscussions, optionTemplateRepo, optionTopics, optionVisibility, @@ -385,17 +399,36 @@ func interactiveRepoEdit(opts *EditOptions, r *api.Repository) error { if err != nil { return err } - case optionVisibility: - opts.Edits.Visibility = &r.Visibility + case optionDiscussions: + opts.Edits.EnableDiscussions = &r.HasDiscussionsEnabled //nolint:staticcheck // SA1019: prompt.SurveyAskOne is deprecated: use Prompter - err = prompt.SurveyAskOne(&survey.Select{ - Message: "Visibility", - Options: []string{"public", "private", "internal"}, - Default: strings.ToLower(r.Visibility), - }, opts.Edits.Visibility) + err = prompt.SurveyAskOne(&survey.Confirm{ + Message: "Enable Discussions?", + Default: r.HasDiscussionsEnabled, + }, opts.Edits.EnableDiscussions) if err != nil { return err } + case optionVisibility: + opts.Edits.Visibility = &r.Visibility + visibilityOptions := []string{"public", "private", "internal"} + selected, err := opts.Prompter.Select("Visibility", strings.ToLower(r.Visibility), visibilityOptions) + if err != nil { + return err + } + confirmed := true + if visibilityOptions[selected] == "private" && + (r.StargazerCount > 0 || r.Watchers.TotalCount > 0) { + cs := opts.IO.ColorScheme() + fmt.Fprintf(opts.IO.ErrOut, "%s Changing the repository visibility to private will cause permanent loss of stars and watchers.\n", cs.WarningIcon()) + confirmed, err = opts.Prompter.Confirm("Do you want to change visibility to private?", false) + if err != nil { + return err + } + } + if confirmed { + opts.Edits.Visibility = &visibilityOptions[selected] + } case optionMergeOptions: var defaultMergeOptions []string var selectedMergeOptions []string diff --git a/pkg/cmd/repo/fork/fork.go b/pkg/cmd/repo/fork/fork.go index e17e4b34d..0721485a8 100644 --- a/pkg/cmd/repo/fork/fork.go +++ b/pkg/cmd/repo/fork/fork.go @@ -33,16 +33,17 @@ type ForkOptions struct { Remotes func() (ghContext.Remotes, error) Since func(time.Time) time.Duration - GitArgs []string - Repository string - Clone bool - Remote bool - PromptClone bool - PromptRemote bool - RemoteName string - Organization string - ForkName string - Rename bool + GitArgs []string + Repository string + Clone bool + Remote bool + PromptClone bool + PromptRemote bool + RemoteName string + Organization string + ForkName string + Rename bool + DefaultBranchOnly bool } // TODO warn about useless flags (--remote, --remote-name) when running from outside a repository @@ -122,6 +123,7 @@ func NewCmdFork(f *cmdutil.Factory, runF func(*ForkOptions) error) *cobra.Comman cmd.Flags().StringVar(&opts.RemoteName, "remote-name", defaultRemoteName, "Specify the name for the new remote") cmd.Flags().StringVar(&opts.Organization, "org", "", "Create the fork in an organization") cmd.Flags().StringVar(&opts.ForkName, "fork-name", "", "Rename the forked repository") + cmd.Flags().BoolVar(&opts.DefaultBranchOnly, "default-branch-only", false, "Only include the default branch in the fork") return cmd } @@ -181,7 +183,7 @@ func forkRun(opts *ForkOptions) error { apiClient := api.NewClientFromHTTP(httpClient) opts.IO.StartProgressIndicator() - forkedRepo, err := api.ForkRepo(apiClient, repoToFork, opts.Organization, opts.ForkName) + forkedRepo, err := api.ForkRepo(apiClient, repoToFork, opts.Organization, opts.ForkName, opts.DefaultBranchOnly) opts.IO.StopProgressIndicator() if err != nil { return fmt.Errorf("failed to fork: %w", err) @@ -327,6 +329,10 @@ func forkRun(opts *ForkOptions) error { return err } + if err := gitClient.Fetch(ctx, "upstream", "", git.WithRepoDir(cloneDir)); err != nil { + return err + } + if connectedToTerminal { fmt.Fprintf(stderr, "%s Cloned fork\n", cs.SuccessIcon()) } diff --git a/pkg/cmd/repo/fork/fork_test.go b/pkg/cmd/repo/fork/fork_test.go index e913583e8..8022cd2dc 100644 --- a/pkg/cmd/repo/fork/fork_test.go +++ b/pkg/cmd/repo/fork/fork_test.go @@ -234,7 +234,7 @@ func TestRepoFork(t *testing.T) { }, httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { - cs.Register(`git remote add -f fork https://github\.com/someone/REPO\.git`, 0, "") + cs.Register(`git remote add fork https://github\.com/someone/REPO\.git`, 0, "") }, wantErrOut: "✓ Created fork someone/REPO\n✓ Added remote fork\n", }, @@ -258,7 +258,7 @@ func TestRepoFork(t *testing.T) { }, httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { - cs.Register(`git remote add -f fork git@github\.com:someone/REPO\.git`, 0, "") + cs.Register(`git remote add fork git@github\.com:someone/REPO\.git`, 0, "") }, wantErrOut: "✓ Created fork someone/REPO\n✓ Added remote fork\n", }, @@ -288,7 +288,7 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register("git remote rename origin upstream", 0, "") - cs.Register(`git remote add -f origin https://github.com/someone/REPO.git`, 0, "") + cs.Register(`git remote add origin https://github.com/someone/REPO.git`, 0, "") }, askStubs: func(as *prompt.AskStubber) { //nolint:staticcheck // SA1019: as.StubOne is deprecated: use StubPrompt @@ -364,7 +364,7 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register("git remote rename origin upstream", 0, "") - cs.Register(`git remote add -f origin https://github.com/someone/REPO.git`, 0, "") + cs.Register(`git remote add origin https://github.com/someone/REPO.git`, 0, "") }, wantErrOut: "✓ Created fork someone/REPO\n✓ Added remote origin\n", }, @@ -418,7 +418,7 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register("git remote rename origin upstream", 0, "") - cs.Register(`git remote add -f origin https://github.com/someone/REPO.git`, 0, "") + cs.Register(`git remote add origin https://github.com/someone/REPO.git`, 0, "") }, }, { @@ -437,7 +437,8 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone --depth 1 https://github.com/someone/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, wantErrOut: "✓ Created fork someone/REPO\n✓ Cloned fork\n", }, @@ -467,7 +468,8 @@ func TestRepoFork(t *testing.T) { }, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone https://github.com/gamehendge/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, wantErrOut: "✓ Created fork gamehendge/REPO\n✓ Cloned fork\n", }, @@ -481,7 +483,8 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone https://github.com/someone/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, wantErrOut: "✓ Created fork someone/REPO\n✓ Cloned fork\n", }, @@ -513,7 +516,8 @@ func TestRepoFork(t *testing.T) { }, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone https://github.com/someone/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, wantErrOut: "✓ Created fork someone/REPO\n✓ Cloned fork\n", }, @@ -534,7 +538,8 @@ func TestRepoFork(t *testing.T) { }, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone https://github.com/someone/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, wantErrOut: "! someone/REPO already exists\n✓ Cloned fork\n", }, @@ -568,7 +573,8 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone https://github.com/someone/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, wantErrOut: "someone/REPO already exists", }, @@ -581,7 +587,8 @@ func TestRepoFork(t *testing.T) { httpStubs: forkPost, execStubs: func(cs *run.CommandStubber) { cs.Register(`git clone https://github.com/someone/REPO\.git`, 0, "") - cs.Register(`git -C REPO remote add -f upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") }, }, { diff --git a/pkg/cmd/repo/rename/rename.go b/pkg/cmd/repo/rename/rename.go index 2d07cd8b0..f979e8101 100644 --- a/pkg/cmd/repo/rename/rename.go +++ b/pkg/cmd/repo/rename/rename.go @@ -60,7 +60,7 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co if len(args) == 1 && !confirm && !opts.HasRepoOverride { if !opts.IO.CanPrompt() { - return cmdutil.FlagErrorf("--confirm required when passing a single argument") + return cmdutil.FlagErrorf("--yes required when passing a single argument") } opts.DoConfirm = true } @@ -68,12 +68,15 @@ func NewCmdRename(f *cmdutil.Factory, runf func(*RenameOptions) error) *cobra.Co if runf != nil { return runf(opts) } + return renameRun(opts) }, } cmdutil.EnableRepoOverride(cmd, f) - cmd.Flags().BoolVarP(&confirm, "confirm", "y", false, "skip confirmation prompt") + cmd.Flags().BoolVar(&confirm, "confirm", false, "Skip confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&confirm, "yes", "y", false, "Skip the confirmation prompt") return cmd } diff --git a/pkg/cmd/repo/rename/rename_test.go b/pkg/cmd/repo/rename/rename_test.go index 523b2ba4a..611e3dcda 100644 --- a/pkg/cmd/repo/rename/rename_test.go +++ b/pkg/cmd/repo/rename/rename_test.go @@ -35,7 +35,7 @@ func TestNewCmdRename(t *testing.T) { }, { name: "one argument no tty confirmed", - input: "REPO --confirm", + input: "REPO --yes", output: RenameOptions{ newRepoSelector: "REPO", }, @@ -43,12 +43,12 @@ func TestNewCmdRename(t *testing.T) { { name: "one argument no tty", input: "REPO", - errMsg: "--confirm required when passing a single argument", + errMsg: "--yes required when passing a single argument", wantErr: true, }, { name: "one argument tty confirmed", - input: "REPO --confirm", + input: "REPO --yes", tty: true, output: RenameOptions{ newRepoSelector: "REPO", diff --git a/pkg/cmd/repo/setdefault/setdefault.go b/pkg/cmd/repo/setdefault/setdefault.go index 0e25a2c25..2cf8b19b9 100644 --- a/pkg/cmd/repo/setdefault/setdefault.go +++ b/pkg/cmd/repo/setdefault/setdefault.go @@ -90,9 +90,9 @@ func NewCmdSetDefault(f *cmdutil.Factory, runF func(*SetDefaultOptions) error) * return cmdutil.FlagErrorf("repository required when not running interactively") } - c := &git.Client{} - - if !c.InGitDirectory(ctx.Background()) { + if isLocal, err := opts.GitClient.IsLocalGitRepo(cmd.Context()); err != nil { + return err + } else if !isLocal { return errors.New("must be run from inside a git repository") } @@ -160,7 +160,7 @@ func setDefaultRun(opts *SetDefaultOptions) error { return err } - knownRepos, err := resolvedRemotes.NetworkRepos() + knownRepos, err := resolvedRemotes.NetworkRepos(0) if err != nil { return err } diff --git a/pkg/cmd/repo/setdefault/setdefault_test.go b/pkg/cmd/repo/setdefault/setdefault_test.go index c90238548..a1c1f44ac 100644 --- a/pkg/cmd/repo/setdefault/setdefault_test.go +++ b/pkg/cmd/repo/setdefault/setdefault_test.go @@ -29,7 +29,7 @@ func TestNewCmdSetDefault(t *testing.T) { { name: "no argument", gitStubs: func(cs *run.CommandStubber) { - cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true") + cs.Register(`git rev-parse --git-dir`, 0, ".git") }, input: "", output: SetDefaultOptions{}, @@ -37,7 +37,7 @@ func TestNewCmdSetDefault(t *testing.T) { { name: "repo argument", gitStubs: func(cs *run.CommandStubber) { - cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true") + cs.Register(`git rev-parse --git-dir`, 0, ".git") }, input: "cli/cli", output: SetDefaultOptions{Repo: ghrepo.New("cli", "cli")}, @@ -52,7 +52,7 @@ func TestNewCmdSetDefault(t *testing.T) { { name: "view flag", gitStubs: func(cs *run.CommandStubber) { - cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true") + cs.Register(`git rev-parse --git-dir`, 0, ".git") }, input: "--view", output: SetDefaultOptions{ViewMode: true}, @@ -60,7 +60,7 @@ func TestNewCmdSetDefault(t *testing.T) { { name: "unset flag", gitStubs: func(cs *run.CommandStubber) { - cs.Register(`git rev-parse --is-inside-work-tree`, 0, "true") + cs.Register(`git rev-parse --git-dir`, 0, ".git") }, input: "--unset", output: SetDefaultOptions{UnsetMode: true}, @@ -68,7 +68,7 @@ func TestNewCmdSetDefault(t *testing.T) { { name: "run from non-git directory", gitStubs: func(cs *run.CommandStubber) { - cs.Register(`git rev-parse --is-inside-work-tree`, 1, "") + cs.Register(`git rev-parse --git-dir`, 128, "") }, input: "", wantErr: true, @@ -83,6 +83,7 @@ func TestNewCmdSetDefault(t *testing.T) { io.SetStderrTTY(true) f := &cmdutil.Factory{ IOStreams: io, + GitClient: &git.Client{GitPath: "/fake/path/to/git"}, } var gotOpts *SetDefaultOptions @@ -121,6 +122,9 @@ func TestDefaultRun(t *testing.T) { repo1, _ := ghrepo.FromFullName("OWNER/REPO") repo2, _ := ghrepo.FromFullName("OWNER2/REPO2") repo3, _ := ghrepo.FromFullName("OWNER3/REPO3") + repo4, _ := ghrepo.FromFullName("OWNER4/REPO4") + repo5, _ := ghrepo.FromFullName("OWNER5/REPO5") + repo6, _ := ghrepo.FromFullName("OWNER6/REPO6") tests := []struct { name string @@ -391,6 +395,55 @@ func TestDefaultRun(t *testing.T) { }, wantStdout: "Found only one known remote repo, OWNER2/REPO2 on github.com.\n✓ Set OWNER2/REPO2 as the default repository for the current directory\n", }, + { + name: "interactive mode more than five remotes", + tty: true, + opts: SetDefaultOptions{}, + remotes: []*context.Remote{ + {Remote: &git.Remote{Name: "origin"}, Repo: repo1}, + {Remote: &git.Remote{Name: "upstream"}, Repo: repo2}, + {Remote: &git.Remote{Name: "other1"}, Repo: repo3}, + {Remote: &git.Remote{Name: "other2"}, Repo: repo4}, + {Remote: &git.Remote{Name: "other3"}, Repo: repo5}, + {Remote: &git.Remote{Name: "other4"}, Repo: repo6}, + }, + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query RepositoryNetwork\b`), + httpmock.GraphQLQuery(`{"data":{ + "repo_000":{"name":"REPO","owner":{"login":"OWNER"}}, + "repo_001":{"name":"REPO2","owner":{"login":"OWNER2"}}, + "repo_002":{"name":"REPO3","owner":{"login":"OWNER3"}}, + "repo_003":{"name":"REPO4","owner":{"login":"OWNER4"}}, + "repo_004":{"name":"REPO5","owner":{"login":"OWNER5"}}, + "repo_005":{"name":"REPO6","owner":{"login":"OWNER6"}} + }}`, + func(query string, inputs map[string]interface{}) { + assert.Contains(t, query, "repo_000") + assert.Contains(t, query, "repo_001") + assert.Contains(t, query, "repo_002") + assert.Contains(t, query, "repo_003") + assert.Contains(t, query, "repo_004") + assert.Contains(t, query, "repo_005") + }), + ) + }, + gitStubs: func(cs *run.CommandStubber) { + cs.Register(`git config --add remote.upstream.gh-resolved base`, 0, "") + }, + prompterStubs: func(pm *prompter.PrompterMock) { + pm.SelectFunc = func(p, d string, opts []string) (int, error) { + switch p { + case "Which repository should be the default?": + prompter.AssertOptions(t, []string{"OWNER/REPO", "OWNER2/REPO2", "OWNER3/REPO3", "OWNER4/REPO4", "OWNER5/REPO5", "OWNER6/REPO6"}, opts) + return prompter.IndexFor(opts, "OWNER2/REPO2") + default: + return -1, prompter.NoSuchPromptErr(p) + } + } + }, + wantStdout: "This command sets the default remote repository to use when querying the\nGitHub API for the locally cloned repository.\n\ngh uses the default repository for things like:\n\n - viewing and creating pull requests\n - viewing and creating issues\n - viewing and creating releases\n - working with Actions\n - adding repository and environment secrets\n\n✓ Set OWNER2/REPO2 as the default repository for the current directory\n", + }, } for _, tt := range tests { diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index 70d49b450..ce0a09cec 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -167,6 +167,7 @@ func newCodespaceCmd(f *cmdutil.Factory) *cobra.Command { &lazyLoadedHTTPClient{factory: f}, ), f.Browser, + f.Remotes, ) cmd := codespaceCmd.NewRootCmd(app) cmd.Use = "codespace" diff --git a/pkg/cmd/run/shared/shared.go b/pkg/cmd/run/shared/shared.go index 557882968..a1a71c4f9 100644 --- a/pkg/cmd/run/shared/shared.go +++ b/pkg/cmd/run/shared/shared.go @@ -77,7 +77,7 @@ type Run struct { workflowName string // cache column WorkflowID int64 `json:"workflow_id"` Number int64 `json:"run_number"` - Attempts uint8 `json:"run_attempt"` + Attempts uint64 `json:"run_attempt"` HeadBranch string `json:"head_branch"` JobsURL string `json:"jobs_url"` HeadCommit Commit `json:"head_commit"` diff --git a/pkg/cmd/search/commits/commits.go b/pkg/cmd/search/commits/commits.go new file mode 100644 index 000000000..7e98cae8a --- /dev/null +++ b/pkg/cmd/search/commits/commits.go @@ -0,0 +1,173 @@ +package commits + +import ( + "fmt" + "time" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/browser" + "github.com/cli/cli/v2/internal/tableprinter" + "github.com/cli/cli/v2/internal/text" + "github.com/cli/cli/v2/pkg/cmd/search/shared" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/cli/cli/v2/pkg/search" + "github.com/spf13/cobra" +) + +type CommitsOptions struct { + Browser browser.Browser + Exporter cmdutil.Exporter + IO *iostreams.IOStreams + Now time.Time + Query search.Query + Searcher search.Searcher + WebMode bool +} + +func NewCmdCommits(f *cmdutil.Factory, runF func(*CommitsOptions) error) *cobra.Command { + var order string + var sort string + opts := &CommitsOptions{ + Browser: f.Browser, + IO: f.IOStreams, + Query: search.Query{Kind: search.KindCommits}, + } + + cmd := &cobra.Command{ + Use: "commits []", + Short: "Search for commits", + Long: heredoc.Doc(` + Search for commits on GitHub. + + The command supports constructing queries using the GitHub search syntax, + using the parameter and qualifier flags, or a combination of the two. + + GitHub search syntax is documented at: + + `), + Example: heredoc.Doc(` + # search commits matching set of keywords "readme" and "typo" + $ gh search commits readme typo + + # search commits matching phrase "bug fix" + $ gh search commits "bug fix" + + # search commits committed by user "monalisa" + $ gh search commits --committer=monalisa + + # search commits authored by users with name "Jane Doe" + $ gh search commits --author-name="Jane Doe" + + # search commits matching hash "8dd03144ffdc6c0d486d6b705f9c7fba871ee7c3" + $ gh search commits --hash=8dd03144ffdc6c0d486d6b705f9c7fba871ee7c3 + + # search commits authored before February 1st, 2022 + $ gh search commits --author-date="<2022-02-01" + `), + RunE: func(c *cobra.Command, args []string) error { + if len(args) == 0 && c.Flags().NFlag() == 0 { + return cmdutil.FlagErrorf("specify search keywords or flags") + } + if opts.Query.Limit < 1 || opts.Query.Limit > shared.SearchMaxResults { + return cmdutil.FlagErrorf("`--limit` must be between 1 and 1000") + } + if c.Flags().Changed("order") { + opts.Query.Order = order + } + if c.Flags().Changed("sort") { + opts.Query.Sort = sort + } + opts.Query.Keywords = args + if runF != nil { + return runF(opts) + } + var err error + opts.Searcher, err = shared.Searcher(f) + if err != nil { + return err + } + return commitsRun(opts) + }, + } + + // Output flags + cmdutil.AddJSONFlags(cmd, &opts.Exporter, search.CommitFields) + cmd.Flags().BoolVarP(&opts.WebMode, "web", "w", false, "Open the search query in the web browser") + + // Query parameter flags + cmd.Flags().IntVarP(&opts.Query.Limit, "limit", "L", 30, "Maximum number of commits to fetch") + cmdutil.StringEnumFlag(cmd, &order, "order", "", "desc", []string{"asc", "desc"}, "Order of commits returned, ignored unless '--sort' flag is specified") + cmdutil.StringEnumFlag(cmd, &sort, "sort", "", "best-match", []string{"author-date", "committer-date"}, "Sort fetched commits") + + // Query qualifier flags + cmd.Flags().StringVar(&opts.Query.Qualifiers.Author, "author", "", "Filter by author") + cmd.Flags().StringVar(&opts.Query.Qualifiers.AuthorDate, "author-date", "", "Filter based on authored `date`") + cmd.Flags().StringVar(&opts.Query.Qualifiers.AuthorEmail, "author-email", "", "Filter on author email") + cmd.Flags().StringVar(&opts.Query.Qualifiers.AuthorName, "author-name", "", "Filter on author name") + cmd.Flags().StringVar(&opts.Query.Qualifiers.Committer, "committer", "", "Filter by committer") + cmd.Flags().StringVar(&opts.Query.Qualifiers.CommitterDate, "committer-date", "", "Filter based on committed `date`") + cmd.Flags().StringVar(&opts.Query.Qualifiers.CommitterEmail, "committer-email", "", "Filter on committer email") + cmd.Flags().StringVar(&opts.Query.Qualifiers.CommitterName, "committer-name", "", "Filter on committer name") + cmd.Flags().StringVar(&opts.Query.Qualifiers.Hash, "hash", "", "Filter by commit hash") + cmdutil.NilBoolFlag(cmd, &opts.Query.Qualifiers.Merge, "merge", "", "Filter on merge commits") + cmd.Flags().StringVar(&opts.Query.Qualifiers.Parent, "parent", "", "Filter by parent hash") + cmd.Flags().StringSliceVar(&opts.Query.Qualifiers.Repo, "repo", nil, "Filter on repository") + cmd.Flags().StringVar(&opts.Query.Qualifiers.Tree, "tree", "", "Filter by tree hash") + cmd.Flags().StringVar(&opts.Query.Qualifiers.User, "owner", "", "Filter on repository owner") + cmdutil.StringSliceEnumFlag(cmd, &opts.Query.Qualifiers.Is, "visibility", "", nil, []string{"public", "private", "internal"}, "Filter based on repository visibility") + + return cmd +} + +func commitsRun(opts *CommitsOptions) error { + io := opts.IO + if opts.WebMode { + url := opts.Searcher.URL(opts.Query) + if io.IsStdoutTTY() { + fmt.Fprintf(io.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(url)) + } + return opts.Browser.Browse(url) + } + io.StartProgressIndicator() + result, err := opts.Searcher.Commits(opts.Query) + io.StopProgressIndicator() + if err != nil { + return err + } + if len(result.Items) == 0 && opts.Exporter == nil { + return cmdutil.NewNoResultsError("no commits matched your search") + } + if err := io.StartPager(); err == nil { + defer io.StopPager() + } else { + fmt.Fprintf(io.ErrOut, "failed to start pager: %v\n", err) + } + if opts.Exporter != nil { + return opts.Exporter.Write(io, result.Items) + } + + return displayResults(io, opts.Now, result) +} + +func displayResults(io *iostreams.IOStreams, now time.Time, results search.CommitsResult) error { + if now.IsZero() { + now = time.Now() + } + cs := io.ColorScheme() + tp := tableprinter.New(io) + tp.HeaderRow("Repo", "SHA", "Message", "Author", "Created") + for _, commit := range results.Items { + tp.AddField(commit.Repo.FullName) + tp.AddField(commit.Sha) + tp.AddField(text.RemoveExcessiveWhitespace(commit.Info.Message)) + tp.AddField(commit.Author.Login) + tp.AddTimeField(now, commit.Info.Author.Date, cs.Gray) + tp.EndRow() + } + if io.IsStdoutTTY() { + header := fmt.Sprintf("Showing %d of %d commits\n\n", len(results.Items), results.Total) + fmt.Fprintf(io.Out, "\n%s", header) + } + return tp.Render() +} diff --git a/pkg/cmd/search/commits/commits_test.go b/pkg/cmd/search/commits/commits_test.go new file mode 100644 index 000000000..9728f3ee1 --- /dev/null +++ b/pkg/cmd/search/commits/commits_test.go @@ -0,0 +1,311 @@ +package commits + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/cli/cli/v2/internal/browser" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/cli/cli/v2/pkg/search" + "github.com/google/shlex" + "github.com/stretchr/testify/assert" +) + +func TestNewCmdCommits(t *testing.T) { + var trueBool = true + tests := []struct { + name string + input string + output CommitsOptions + wantErr bool + errMsg string + }{ + { + name: "no arguments", + input: "", + wantErr: true, + errMsg: "specify search keywords or flags", + }, + { + name: "keyword arguments", + input: "some search terms", + output: CommitsOptions{ + Query: search.Query{Keywords: []string{"some", "search", "terms"}, Kind: "commits", Limit: 30}, + }, + }, + { + name: "web flag", + input: "--web", + output: CommitsOptions{ + Query: search.Query{Keywords: []string{}, Kind: "commits", Limit: 30}, + WebMode: true, + }, + }, + { + name: "limit flag", + input: "--limit 10", + output: CommitsOptions{Query: search.Query{Keywords: []string{}, Kind: "commits", Limit: 10}}, + }, + { + name: "invalid limit flag", + input: "--limit 1001", + wantErr: true, + errMsg: "`--limit` must be between 1 and 1000", + }, + { + name: "order flag", + input: "--order asc", + output: CommitsOptions{ + Query: search.Query{Keywords: []string{}, Kind: "commits", Limit: 30, Order: "asc"}, + }, + }, + { + name: "invalid order flag", + input: "--order invalid", + wantErr: true, + errMsg: "invalid argument \"invalid\" for \"--order\" flag: valid values are {asc|desc}", + }, + { + name: "qualifier flags", + input: ` + --author=foo + --author-date=01-01-2000 + --author-email=foo@example.com + --author-name=Foo + --committer=bar + --committer-date=01-02-2000 + --committer-email=bar@example.com + --committer-name=Bar + --hash=aaa + --merge + --parent=bbb + --repo=owner/repo + --tree=ccc + --owner=owner + --visibility=public + `, + output: CommitsOptions{ + Query: search.Query{ + Keywords: []string{}, + Kind: "commits", + Limit: 30, + Qualifiers: search.Qualifiers{ + Author: "foo", + AuthorDate: "01-01-2000", + AuthorEmail: "foo@example.com", + AuthorName: "Foo", + Committer: "bar", + CommitterDate: "01-02-2000", + CommitterEmail: "bar@example.com", + CommitterName: "Bar", + Hash: "aaa", + Merge: &trueBool, + Parent: "bbb", + Repo: []string{"owner/repo"}, + Tree: "ccc", + User: "owner", + Is: []string{"public"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + f := &cmdutil.Factory{ + IOStreams: ios, + } + argv, err := shlex.Split(tt.input) + assert.NoError(t, err) + var gotOpts *CommitsOptions + cmd := NewCmdCommits(f, func(opts *CommitsOptions) error { + gotOpts = opts + return nil + }) + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + + _, err = cmd.ExecuteC() + if tt.wantErr { + assert.EqualError(t, err, tt.errMsg) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.output.Query, gotOpts.Query) + assert.Equal(t, tt.output.WebMode, gotOpts.WebMode) + }) + } +} + +func TestCommitsRun(t *testing.T) { + var now = time.Date(2023, 1, 17, 12, 30, 0, 0, time.UTC) + var author = search.CommitUser{Date: time.Date(2022, 12, 27, 11, 30, 0, 0, time.UTC)} + var committer = search.CommitUser{Date: time.Date(2022, 12, 28, 12, 30, 0, 0, time.UTC)} + var query = search.Query{ + Keywords: []string{"cli"}, + Kind: "commits", + Limit: 30, + Qualifiers: search.Qualifiers{}, + } + tests := []struct { + errMsg string + name string + opts *CommitsOptions + tty bool + wantErr bool + wantStderr string + wantStdout string + }{ + { + name: "displays results tty", + opts: &CommitsOptions{ + Query: query, + Searcher: &search.SearcherMock{ + CommitsFunc: func(query search.Query) (search.CommitsResult, error) { + return search.CommitsResult{ + IncompleteResults: false, + Items: []search.Commit{ + { + Author: search.User{Login: "monalisa"}, + Info: search.CommitInfo{Author: author, Committer: committer, Message: "hello"}, + Repo: search.Repository{FullName: "test/cli"}, + Sha: "aaaaaaaa", + }, + { + Author: search.User{Login: "johnnytest"}, + Info: search.CommitInfo{Author: author, Committer: committer, Message: "hi"}, + Repo: search.Repository{FullName: "test/cliing", IsPrivate: true}, + Sha: "bbbbbbbb", + }, + { + Author: search.User{Login: "hubot"}, + Info: search.CommitInfo{Author: author, Committer: committer, Message: "greetings"}, + Repo: search.Repository{FullName: "cli/cli"}, + Sha: "cccccccc", + }, + }, + Total: 300, + }, nil + }, + }, + }, + tty: true, + wantStdout: "\nShowing 3 of 300 commits\n\nREPO SHA MESSAGE AUTHOR CREATED\ntest/cli aaaaaaaa hello monalisa about 21 days ago\ntest/cliing bbbbbbbb hi johnnytest about 21 days ago\ncli/cli cccccccc greetings hubot about 21 days ago\n", + }, + { + name: "displays results notty", + opts: &CommitsOptions{ + Query: query, + Searcher: &search.SearcherMock{ + CommitsFunc: func(query search.Query) (search.CommitsResult, error) { + return search.CommitsResult{ + IncompleteResults: false, + Items: []search.Commit{ + { + Author: search.User{Login: "monalisa"}, + Info: search.CommitInfo{Author: author, Committer: committer, Message: "hello"}, + Repo: search.Repository{FullName: "test/cli"}, + Sha: "aaaaaaaa", + }, + { + Author: search.User{Login: "johnnytest"}, + Info: search.CommitInfo{Author: author, Committer: committer, Message: "hi"}, + Repo: search.Repository{FullName: "test/cliing", IsPrivate: true}, + Sha: "bbbbbbbb", + }, + { + Author: search.User{Login: "hubot"}, + Info: search.CommitInfo{Author: author, Committer: committer, Message: "greetings"}, + Repo: search.Repository{FullName: "cli/cli"}, + Sha: "cccccccc", + }, + }, + Total: 300, + }, nil + }, + }, + }, + wantStdout: "test/cli\taaaaaaaa\thello\tmonalisa\t2022-12-27T11:30:00Z\ntest/cliing\tbbbbbbbb\thi\tjohnnytest\t2022-12-27T11:30:00Z\ncli/cli\tcccccccc\tgreetings\thubot\t2022-12-27T11:30:00Z\n", + }, + { + name: "displays no results", + opts: &CommitsOptions{ + Query: query, + Searcher: &search.SearcherMock{ + CommitsFunc: func(query search.Query) (search.CommitsResult, error) { + return search.CommitsResult{}, nil + }, + }, + }, + wantErr: true, + errMsg: "no commits matched your search", + }, + { + name: "displays search error", + opts: &CommitsOptions{ + Query: query, + Searcher: &search.SearcherMock{ + CommitsFunc: func(query search.Query) (search.CommitsResult, error) { + return search.CommitsResult{}, fmt.Errorf("error with query") + }, + }, + }, + errMsg: "error with query", + wantErr: true, + }, + { + name: "opens browser for web mode tty", + opts: &CommitsOptions{ + Browser: &browser.Stub{}, + Query: query, + Searcher: &search.SearcherMock{ + URLFunc: func(query search.Query) string { + return "https://github.com/search?type=commits&q=cli" + }, + }, + WebMode: true, + }, + tty: true, + wantStderr: "Opening github.com/search in your browser.\n", + }, + { + name: "opens browser for web mode notty", + opts: &CommitsOptions{ + Browser: &browser.Stub{}, + Query: query, + Searcher: &search.SearcherMock{ + URLFunc: func(query search.Query) string { + return "https://github.com/search?type=commits&q=cli" + }, + }, + WebMode: true, + }, + }, + } + for _, tt := range tests { + ios, _, stdout, stderr := iostreams.Test() + ios.SetStdinTTY(tt.tty) + ios.SetStdoutTTY(tt.tty) + ios.SetStderrTTY(tt.tty) + tt.opts.IO = ios + tt.opts.Now = now + t.Run(tt.name, func(t *testing.T) { + err := commitsRun(tt.opts) + if tt.wantErr { + assert.EqualError(t, err, tt.errMsg) + return + } else if err != nil { + t.Fatalf("commitsRun unexpected error: %v", err) + } + assert.Equal(t, tt.wantStdout, stdout.String()) + assert.Equal(t, tt.wantStderr, stderr.String()) + }) + } +} diff --git a/pkg/cmd/search/repos/repos.go b/pkg/cmd/search/repos/repos.go index 25c796c37..a8066bc7b 100644 --- a/pkg/cmd/search/repos/repos.go +++ b/pkg/cmd/search/repos/repos.go @@ -7,12 +7,12 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/browser" + "github.com/cli/cli/v2/internal/tableprinter" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/search/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/search" - "github.com/cli/cli/v2/utils" "github.com/spf13/cobra" ) @@ -158,8 +158,8 @@ func displayResults(io *iostreams.IOStreams, now time.Time, results search.Repos now = time.Now() } cs := io.ColorScheme() - //nolint:staticcheck // SA1019: utils.NewTablePrinter is deprecated: use internal/tableprinter - tp := utils.NewTablePrinter(io) + tp := tableprinter.New(io) + tp.HeaderRow("Name", "Description", "Visibility", "Updated") for _, repo := range results.Items { tags := []string{visibilityLabel(repo)} if repo.IsFork { @@ -173,15 +173,10 @@ func displayResults(io *iostreams.IOStreams, now time.Time, results search.Repos if repo.IsPrivate { infoColor = cs.Yellow } - tp.AddField(repo.FullName, nil, cs.Bold) - description := repo.Description - tp.AddField(text.RemoveExcessiveWhitespace(description), nil, nil) - tp.AddField(info, nil, infoColor) - if tp.IsTTY() { - tp.AddField(text.FuzzyAgoAbbr(now, repo.UpdatedAt), nil, cs.Gray) - } else { - tp.AddField(repo.UpdatedAt.Format(time.RFC3339), nil, nil) - } + tp.AddField(repo.FullName, tableprinter.WithColor(cs.Bold)) + tp.AddField(text.RemoveExcessiveWhitespace(repo.Description)) + tp.AddField(info, tableprinter.WithColor(infoColor)) + tp.AddTimeField(now, repo.UpdatedAt, cs.Gray) tp.EndRow() } if io.IsStdoutTTY() { diff --git a/pkg/cmd/search/repos/repos_test.go b/pkg/cmd/search/repos/repos_test.go index d0410b196..7b80c9177 100644 --- a/pkg/cmd/search/repos/repos_test.go +++ b/pkg/cmd/search/repos/repos_test.go @@ -188,7 +188,7 @@ func TestReposRun(t *testing.T) { }, }, tty: true, - wantStdout: "\nShowing 3 of 300 repositories\n\ntest/cli of course private, archived Feb 28, 2021\ntest/cliing wow public, fork Feb 28, 2021\ncli/cli so much internal Feb 28, 2021\n", + wantStdout: "\nShowing 3 of 300 repositories\n\nNAME DESCRIPTION VISIBILITY UPDATED\ntest/cli of course private, archived about 1 year ago\ntest/cliing wow public, fork about 1 year ago\ncli/cli so much internal about 1 year ago\n", }, { name: "displays results notty", diff --git a/pkg/cmd/search/search.go b/pkg/cmd/search/search.go index 188981670..7b9a4a653 100644 --- a/pkg/cmd/search/search.go +++ b/pkg/cmd/search/search.go @@ -4,6 +4,7 @@ import ( "github.com/cli/cli/v2/pkg/cmdutil" "github.com/spf13/cobra" + searchCommitsCmd "github.com/cli/cli/v2/pkg/cmd/search/commits" searchIssuesCmd "github.com/cli/cli/v2/pkg/cmd/search/issues" searchPrsCmd "github.com/cli/cli/v2/pkg/cmd/search/prs" searchReposCmd "github.com/cli/cli/v2/pkg/cmd/search/repos" @@ -16,6 +17,7 @@ func NewCmdSearch(f *cmdutil.Factory) *cobra.Command { Long: "Search across all of GitHub.", } + cmd.AddCommand(searchCommitsCmd.NewCmdCommits(f, nil)) cmd.AddCommand(searchIssuesCmd.NewCmdIssues(f, nil)) cmd.AddCommand(searchPrsCmd.NewCmdPrs(f, nil)) cmd.AddCommand(searchReposCmd.NewCmdRepos(f, nil)) diff --git a/pkg/cmd/search/shared/shared.go b/pkg/cmd/search/shared/shared.go index a5cdfe8dc..e751848d4 100644 --- a/pkg/cmd/search/shared/shared.go +++ b/pkg/cmd/search/shared/shared.go @@ -7,11 +7,11 @@ import ( "time" "github.com/cli/cli/v2/internal/browser" + "github.com/cli/cli/v2/internal/tableprinter" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/search" - "github.com/cli/cli/v2/utils" ) type EntityType int @@ -95,43 +95,46 @@ func displayIssueResults(io *iostreams.IOStreams, now time.Time, et EntityType, if now.IsZero() { now = time.Now() } + isTTY := io.IsStdoutTTY() cs := io.ColorScheme() - //nolint:staticcheck // SA1019: utils.NewTablePrinter is deprecated: use internal/tableprinter - tp := utils.NewTablePrinter(io) + tp := tableprinter.New(io) + if et == Both { + tp.HeaderRow("Kind", "Repo", "ID", "Title", "Labels", "Updated") + } else { + tp.HeaderRow("Repo", "ID", "Title", "Labels", "Updated") + } for _, issue := range results.Items { if et == Both { kind := "issue" if issue.IsPullRequest() { kind = "pr" } - tp.AddField(kind, nil, nil) + tp.AddField(kind) } comp := strings.Split(issue.RepositoryURL, "/") name := comp[len(comp)-2:] - tp.AddField(strings.Join(name, "/"), nil, nil) + tp.AddField(strings.Join(name, "/")) issueNum := strconv.Itoa(issue.Number) - if tp.IsTTY() { + if isTTY { issueNum = "#" + issueNum } if issue.IsPullRequest() { - tp.AddField(issueNum, nil, cs.ColorFromString(colorForPRState(issue.State()))) + color := tableprinter.WithColor(cs.ColorFromString(colorForPRState(issue.State()))) + tp.AddField(issueNum, color) } else { - tp.AddField(issueNum, nil, cs.ColorFromString(colorForIssueState(issue.State(), issue.StateReason))) + color := tableprinter.WithColor(cs.ColorFromString(colorForIssueState(issue.State(), issue.StateReason))) + tp.AddField(issueNum, color) } - if !tp.IsTTY() { - tp.AddField(issue.State(), nil, nil) - } - tp.AddField(text.RemoveExcessiveWhitespace(issue.Title), nil, nil) - tp.AddField(listIssueLabels(&issue, cs, tp.IsTTY()), nil, nil) - if tp.IsTTY() { - tp.AddField(text.FuzzyAgo(now, issue.UpdatedAt), nil, cs.Gray) - } else { - tp.AddField(issue.UpdatedAt.String(), nil, nil) + if !isTTY { + tp.AddField(issue.State()) } + tp.AddField(text.RemoveExcessiveWhitespace(issue.Title)) + tp.AddField(listIssueLabels(&issue, cs, isTTY)) + tp.AddTimeField(now, issue.UpdatedAt, cs.Gray) tp.EndRow() } - if io.IsStdoutTTY() { + if isTTY { var header string switch et { case Both: diff --git a/pkg/cmd/search/shared/shared_test.go b/pkg/cmd/search/shared/shared_test.go index 8c4d4ca45..9d977cc83 100644 --- a/pkg/cmd/search/shared/shared_test.go +++ b/pkg/cmd/search/shared/shared_test.go @@ -64,7 +64,7 @@ func TestSearchIssues(t *testing.T) { }, }, tty: true, - wantStdout: "\nShowing 3 of 300 issues\n\ntest/cli #123 something broken bug, p1 about 1 year ago\nwhat/what #456 feature request enhancement about 1 year ago\nblah/test #789 some title about 1 year ago\n", + wantStdout: "\nShowing 3 of 300 issues\n\nREPO ID TITLE LABELS UPDATED\ntest/cli #123 something broken bug, p1 about 1 year ago\nwhat/what #456 feature request enhancement about 1 year ago\nblah/test #789 some title about 1 year ago\n", }, { name: "displays issues and pull requests tty", @@ -85,7 +85,7 @@ func TestSearchIssues(t *testing.T) { }, }, tty: true, - wantStdout: "\nShowing 2 of 300 issues and pull requests\n\nissue test/cli #123 bug bug, p1 about 1 year ago\npr what/what #456 fix bug fix about 1 year ago\n", + wantStdout: "\nShowing 2 of 300 issues and pull requests\n\nKIND REPO ID TITLE LABELS UPDATED\nissue test/cli #123 bug bug, p1 about 1 year ago\npr what/what #456 fix bug fix about 1 year ago\n", }, { name: "displays results notty", @@ -106,7 +106,7 @@ func TestSearchIssues(t *testing.T) { }, }, }, - wantStdout: "test/cli\t123\topen\tsomething broken\tbug, p1\t2021-02-28 12:30:00 +0000 UTC\nwhat/what\t456\tclosed\tfeature request\tenhancement\t2021-02-28 12:30:00 +0000 UTC\nblah/test\t789\topen\tsome title\t\t2021-02-28 12:30:00 +0000 UTC\n", + wantStdout: "test/cli\t123\topen\tsomething broken\tbug, p1\t2021-02-28T12:30:00Z\nwhat/what\t456\tclosed\tfeature request\tenhancement\t2021-02-28T12:30:00Z\nblah/test\t789\topen\tsome title\t\t2021-02-28T12:30:00Z\n", }, { name: "displays issues and pull requests notty", @@ -126,7 +126,7 @@ func TestSearchIssues(t *testing.T) { }, }, }, - wantStdout: "issue\ttest/cli\t123\topen\tbug\tbug, p1\t2021-02-28 12:30:00 +0000 UTC\npr\twhat/what\t456\topen\tfix bug\tfix\t2021-02-28 12:30:00 +0000 UTC\n", + wantStdout: "issue\ttest/cli\t123\topen\tbug\tbug, p1\t2021-02-28T12:30:00Z\npr\twhat/what\t456\topen\tfix bug\tfix\t2021-02-28T12:30:00Z\n", }, { name: "displays no results", diff --git a/pkg/cmd/secret/set/http.go b/pkg/cmd/secret/set/http.go index f36c2b59e..d5c2bf436 100644 --- a/pkg/cmd/secret/set/http.go +++ b/pkg/cmd/secret/set/http.go @@ -20,9 +20,9 @@ type SecretPayload struct { KeyID string `json:"key_id"` } -// The Codespaces Secret API currently expects repositories IDs as strings -type CodespacesSecretPayload struct { +type DependabotSecretPayload struct { EncryptedValue string `json:"encrypted_value"` + Visibility string `json:"visibility,omitempty"` Repositories []string `json:"selected_repository_ids,omitempty"` KeyID string `json:"key_id"` } @@ -70,31 +70,40 @@ func putSecret(client *api.Client, host, path string, payload interface{}) error } func putOrgSecret(client *api.Client, host string, pk *PubKey, orgName, visibility, secretName, eValue string, repositoryIDs []int64, app shared.App) error { + path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName) + + if app == shared.Dependabot { + repos := make([]string, len(repositoryIDs)) + for i, id := range repositoryIDs { + repos[i] = strconv.FormatInt(id, 10) + } + + payload := DependabotSecretPayload{ + EncryptedValue: eValue, + KeyID: pk.ID, + Repositories: repos, + Visibility: visibility, + } + + return putSecret(client, host, path, payload) + } + payload := SecretPayload{ EncryptedValue: eValue, KeyID: pk.ID, Repositories: repositoryIDs, Visibility: visibility, } - path := fmt.Sprintf("orgs/%s/%s/secrets/%s", orgName, app, secretName) return putSecret(client, host, path, payload) } func putUserSecret(client *api.Client, host string, pk *PubKey, key, eValue string, repositoryIDs []int64) error { - payload := CodespacesSecretPayload{ + payload := SecretPayload{ EncryptedValue: eValue, KeyID: pk.ID, + Repositories: repositoryIDs, } - - if len(repositoryIDs) > 0 { - repositoryStringIDs := make([]string, len(repositoryIDs)) - for i, id := range repositoryIDs { - repositoryStringIDs[i] = strconv.FormatInt(id, 10) - } - payload.Repositories = repositoryStringIDs - } - path := fmt.Sprintf("user/codespaces/secrets/%s", key) return putSecret(client, host, path, payload) } diff --git a/pkg/cmd/secret/set/set_test.go b/pkg/cmd/secret/set/set_test.go index d9313f81f..592527c5f 100644 --- a/pkg/cmd/secret/set/set_test.go +++ b/pkg/cmd/secret/set/set_test.go @@ -333,11 +333,12 @@ func Test_setRun_env(t *testing.T) { func Test_setRun_org(t *testing.T) { tests := []struct { - name string - opts *SetOptions - wantVisibility shared.Visibility - wantRepositories []int64 - wantApp string + name string + opts *SetOptions + wantVisibility shared.Visibility + wantRepositories []int64 + wantDependabotRepositories []string + wantApp string }{ { name: "all vis", @@ -362,10 +363,21 @@ func Test_setRun_org(t *testing.T) { opts: &SetOptions{ OrgName: "UmbrellaCorporation", Visibility: shared.All, - Application: "dependabot", + Application: shared.Dependabot, }, wantApp: "dependabot", }, + { + name: "Dependabot selected visibility", + opts: &SetOptions{ + OrgName: "UmbrellaCorporation", + Visibility: shared.Selected, + Application: shared.Dependabot, + RepositoryNames: []string{"birkin", "UmbrellaCorporation/wesker"}, + }, + wantDependabotRepositories: []string{"1", "2"}, + wantApp: "dependabot", + }, } for _, tt := range tests { @@ -410,13 +422,24 @@ func Test_setRun_org(t *testing.T) { data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body) assert.NoError(t, err) - var payload SecretPayload - err = json.Unmarshal(data, &payload) - assert.NoError(t, err) - assert.Equal(t, payload.KeyID, "123") - assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") - assert.Equal(t, payload.Visibility, tt.opts.Visibility) - assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories) + + if tt.opts.Application == shared.Dependabot { + var payload DependabotSecretPayload + err = json.Unmarshal(data, &payload) + assert.NoError(t, err) + assert.Equal(t, payload.KeyID, "123") + assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") + assert.Equal(t, payload.Visibility, tt.opts.Visibility) + assert.ElementsMatch(t, payload.Repositories, tt.wantDependabotRepositories) + } else { + var payload SecretPayload + err = json.Unmarshal(data, &payload) + assert.NoError(t, err) + assert.Equal(t, payload.KeyID, "123") + assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") + assert.Equal(t, payload.Visibility, tt.opts.Visibility) + assert.ElementsMatch(t, payload.Repositories, tt.wantRepositories) + } }) } } @@ -426,7 +449,7 @@ func Test_setRun_user(t *testing.T) { name string opts *SetOptions wantVisibility shared.Visibility - wantRepositories []string + wantRepositories []int64 }{ { name: "all vis", @@ -442,7 +465,7 @@ func Test_setRun_user(t *testing.T) { Visibility: shared.Selected, RepositoryNames: []string{"cli/cli", "github/hub"}, }, - wantRepositories: []string{"212613049", "401025"}, + wantRepositories: []int64{212613049, 401025}, }, } @@ -481,7 +504,7 @@ func Test_setRun_user(t *testing.T) { data, err := io.ReadAll(reg.Requests[len(reg.Requests)-1].Body) assert.NoError(t, err) - var payload CodespacesSecretPayload + var payload SecretPayload err = json.Unmarshal(data, &payload) assert.NoError(t, err) assert.Equal(t, payload.KeyID, "123") diff --git a/pkg/cmd/ssh-key/delete/delete.go b/pkg/cmd/ssh-key/delete/delete.go index a11cbd769..de53f1391 100644 --- a/pkg/cmd/ssh-key/delete/delete.go +++ b/pkg/cmd/ssh-key/delete/delete.go @@ -37,17 +37,21 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co opts.KeyID = args[0] if !opts.IO.CanPrompt() && !opts.Confirmed { - return cmdutil.FlagErrorf("--confirm required when not running interactively") + return cmdutil.FlagErrorf("--yes required when not running interactively") } if runF != nil { return runF(opts) } + return deleteRun(opts) }, } - cmd.Flags().BoolVarP(&opts.Confirmed, "confirm", "y", false, "Skip the confirmation prompt") + cmd.Flags().BoolVar(&opts.Confirmed, "confirm", false, "Skip the confirmation prompt") + _ = cmd.Flags().MarkDeprecated("confirm", "use `--yes` instead") + cmd.Flags().BoolVarP(&opts.Confirmed, "yes", "y", false, "Skip the confirmation prompt") + return cmd } diff --git a/pkg/cmd/ssh-key/delete/delete_test.go b/pkg/cmd/ssh-key/delete/delete_test.go index 437443c55..85e79de3a 100644 --- a/pkg/cmd/ssh-key/delete/delete_test.go +++ b/pkg/cmd/ssh-key/delete/delete_test.go @@ -32,7 +32,7 @@ func TestNewCmdDelete(t *testing.T) { { name: "confirm flag tty", tty: true, - input: "123 --confirm", + input: "123 --yes", output: DeleteOptions{KeyID: "123", Confirmed: true}, }, { @@ -45,11 +45,11 @@ func TestNewCmdDelete(t *testing.T) { name: "no tty", input: "123", wantErr: true, - wantErrMsg: "--confirm required when not running interactively", + wantErrMsg: "--yes required when not running interactively", }, { name: "confirm flag no tty", - input: "123 --confirm", + input: "123 --yes", output: DeleteOptions{KeyID: "123", Confirmed: true}, }, { diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index cae241516..031db449e 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -24,6 +24,10 @@ func MatchAny(*http.Request) bool { return true } +// REST returns a matcher to a request for the HTTP method and URL escaped path p. +// For example, to match a GET request to `/api/v3/repos/octocat/hello-world/` +// use REST("GET", "api/v3/repos/octocat/hello-world") +// To match a GET request to `/user` use REST("GET", "user") func REST(method, p string) Matcher { return func(req *http.Request) bool { if !strings.EqualFold(req.Method, method) { diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index 570e8615c..cbfa5d458 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -17,7 +17,6 @@ import ( "fmt" "net/url" "strings" - "time" "github.com/opentracing/opentracing-go" ) @@ -29,7 +28,6 @@ type logger interface { // An Options specifies Live Share connection parameters. type Options struct { - ClientName string // ClientName is the name of the connecting client. SessionID string SessionToken string // token for SSH session RelaySAS string @@ -41,9 +39,6 @@ type Options struct { // uri returns a websocket URL for the specified options. func (opts *Options) uri(action string) (string, error) { - if opts.ClientName == "" { - return "", errors.New("ClientName is required") - } if opts.SessionID == "" { return "", errors.New("SessionID is required") } @@ -112,11 +107,9 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { s := &Session{ ssh: ssh, rpc: rpc, - clientName: opts.ClientName, keepAliveReason: make(chan string, 1), logger: opts.Logger, } - go s.heartbeat(ctx, 1*time.Minute) return s, nil } diff --git a/pkg/liveshare/client_test.go b/pkg/liveshare/client_test.go index 4b2908858..a39c48f7b 100644 --- a/pkg/liveshare/client_test.go +++ b/pkg/liveshare/client_test.go @@ -15,7 +15,6 @@ import ( func TestConnect(t *testing.T) { opts := Options{ - ClientName: "liveshare-client", SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", diff --git a/pkg/liveshare/options_test.go b/pkg/liveshare/options_test.go index d244193b4..830c59104 100644 --- a/pkg/liveshare/options_test.go +++ b/pkg/liveshare/options_test.go @@ -41,7 +41,6 @@ func checkBadOptions(t *testing.T, opts Options) { func TestOptionsURI(t *testing.T) { opts := Options{ - ClientName: "liveshare-client", SessionID: "sess-id", SessionToken: "sess-token", RelaySAS: "sas", diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index f042eeaea..5f2742209 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -16,6 +16,33 @@ type portForwardingSession interface { KeepAlive(string) } +type ReadWriteHalfCloser interface { + io.ReadWriteCloser + CloseWrite() error +} + +type combinedReadWriteHalfCloser struct { + io.ReadCloser + io.WriteCloser +} + +func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser { + return &combinedReadWriteHalfCloser{reader, writer} +} + +func (crwc *combinedReadWriteHalfCloser) Close() error { + werr := crwc.WriteCloser.Close() + rerr := crwc.ReadCloser.Close() + if werr != nil { + return werr + } + return rerr +} + +func (crwc *combinedReadWriteHalfCloser) CloseWrite() error { + return crwc.WriteCloser.Close() +} + // A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote // container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { @@ -48,7 +75,7 @@ func NewPortForwarder(session portForwardingSession, name string, remotePort int // until it encounters the first error, which may include context // cancellation. Its error result is always non-nil. The caller is // responsible for closing the listening port. -func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { +func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err @@ -65,7 +92,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List } go func() { for { - conn, err := listen.Accept() + conn, err := listen.AcceptTCP() if err != nil { sendError(err) return @@ -84,7 +111,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List // Forward forwards traffic between the container's remote port and // the specified read/write stream. On return, the stream is closed. -func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { +func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { conn.Close() @@ -143,7 +170,7 @@ func (t *trafficMonitor) Read(p []byte) (n int, err error) { } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") defer span.Finish() @@ -165,9 +192,12 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, co // bi-directional copy of data. errs := make(chan error, 2) - copyConn := func(w io.Writer, r io.Reader) { + copyConn := func(w ReadWriteHalfCloser, r io.Reader) { _, err := io.Copy(w, r) errs <- err + + // Ignore errors here, we call the full Close() later and catch that error + _ = w.CloseWrite() } var ( diff --git a/pkg/liveshare/port_forwarder_test.go b/pkg/liveshare/port_forwarder_test.go index b02165849..61acde368 100644 --- a/pkg/liveshare/port_forwarder_test.go +++ b/pkg/liveshare/port_forwarder_test.go @@ -71,6 +71,10 @@ func TestPortForwarderStart(t *testing.T) { t.Fatal(err) } defer listen.Close() + tcpListener, ok := listen.(*net.TCPListener) + if !ok { + t.Fatal("net.Listen did not return a TCPListener") + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -82,7 +86,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error, 2) go func() { - done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen) + done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener) }() go func() { diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 697659021..e5ec86703 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -3,7 +3,6 @@ package liveshare import ( "context" "fmt" - "time" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" @@ -23,6 +22,7 @@ type LiveshareSession interface { KeepAlive(string) OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) StartSharing(context.Context, string, int) (ChannelID, error) + GetKeepAliveReason() string } // A Session represents the session between a connected Live Share client and server. @@ -30,7 +30,6 @@ type Session struct { ssh *sshSession rpc *rpcClient - clientName string keepAliveReason chan string logger logger } @@ -48,42 +47,17 @@ func (s *Session) Close() error { return nil } +// Fetches the keep alive reason from the channel and returns it. +func (s *Session) GetKeepAliveReason() string { + return <-s.keepAliveReason +} + // registerRequestHandler registers a handler for the given request type with the RPC // server and returns a callback function to deregister the handler func (s *Session) registerRequestHandler(requestType string, h handler) func() { return s.rpc.register(requestType, h) } -// heartbeat runs until context cancellation, periodically checking whether there is a -// reason to keep the connection alive, and if so, notifying the Live Share host to do so. -// Heartbeat ensures it does not send more than one request every "interval" to ratelimit -// how many KeepAlives we send at a time. -func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - s.logger.Println("Heartbeat tick") - reason := <-s.keepAliveReason - s.logger.Println("Keep alive reason: " + reason) - if err := s.notifyHostOfActivity(ctx, reason); err != nil { - s.logger.Printf("Failed to notify host of activity: %s\n", err) - } - } - } -} - -// notifyHostOfActivity notifies the Live Share host of client activity. -func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error { - activities := []string{activity} - params := []interface{}{s.clientName, activities} - return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil) -} - // KeepAlive accepts a reason that is retained if there is no active reason // to send to the server. func (s *Session) KeepAlive(reason string) { diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index cfe8ccd11..9de50ff92 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -10,14 +10,11 @@ import ( "strings" "sync" "testing" - "time" livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) -const mockClientName = "liveshare-client" - func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil @@ -34,7 +31,6 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, } session, err := Connect(context.Background(), Options{ - ClientName: mockClientName, SessionID: "session-id", SessionToken: sessionToken, RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), @@ -254,151 +250,6 @@ func TestKeepAliveNonBlocking(t *testing.T) { // timing out } -func TestNotifyHostOfActivity(t *testing.T) { - notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) { - var req []interface{} - if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { - return nil, fmt.Errorf("unmarshal req: %w", err) - } - if len(req) < 2 { - return nil, errors.New("request arguments is less than 2") - } - - if clientName, ok := req[0].(string); ok { - if clientName != mockClientName { - return nil, fmt.Errorf( - "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, - ) - } - } else { - return nil, errors.New("clientName param is not a string") - } - - if acs, ok := req[1].([]interface{}); ok { - if fmt.Sprintf("%s", acs) != "[input]" { - return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) - } - } else { - return nil, errors.New("activities param is not a slice") - } - - return nil, nil - } - svc := livesharetest.WithService( - "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, - ) - testServer, session, err := makeMockSession(svc) - if err != nil { - t.Fatalf("creating mock session: %v", err) - } - defer testServer.Close() - ctx := context.Background() - done := make(chan error) - go func() { - done <- session.notifyHostOfActivity(ctx, "input") - }() - select { - case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) - case err := <-done: - if err != nil { - t.Errorf("error from client: %v", err) - } - } -} - -func TestSessionHeartbeat(t *testing.T) { - var ( - requestsMu sync.Mutex - requests int - wg sync.WaitGroup - ) - wg.Add(1) - notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) { - defer wg.Done() - requestsMu.Lock() - requests++ - requestsMu.Unlock() - - var req []interface{} - if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { - return nil, fmt.Errorf("unmarshal req: %w", err) - } - if len(req) < 2 { - return nil, errors.New("request arguments is less than 2") - } - - if clientName, ok := req[0].(string); ok { - if clientName != mockClientName { - return nil, fmt.Errorf( - "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, - ) - } - } else { - return nil, errors.New("clientName param is not a string") - } - - if acs, ok := req[1].([]interface{}); ok { - if fmt.Sprintf("%s", acs) != "[input]" { - return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) - } - } else { - return nil, errors.New("activities param is not a slice") - } - - return nil, nil - } - svc := livesharetest.WithService( - "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, - ) - testServer, session, err := makeMockSession(svc) - if err != nil { - t.Fatalf("creating mock session: %v", err) - } - defer testServer.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - done := make(chan struct{}) - - logger := newMockLogger() - session.logger = logger - - go session.heartbeat(ctx, 50*time.Millisecond) - go func() { - session.KeepAlive("input") - wg.Wait() - wg.Add(1) - session.KeepAlive("input") - wg.Wait() - done <- struct{}{} - }() - - select { - case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) - case <-done: - activityCount := strings.Count(logger.String(), "input") - // by design KeepAlive can drop requests, and therefore there is zero guarantee - // that we actually get two requests if the network happened to be slow (rarely) - // during testing. - if activityCount != 1 && activityCount != 2 { - t.Errorf("unexpected number of activities, expected: 1-2, got: %d", activityCount) - } - - requestsMu.Lock() - rc := requests - requestsMu.Unlock() - // though this could be also dropped, the sync.WaitGroup above guarantees - // that it gets called a second time. - if rc != 2 { - t.Errorf("unexpected number of requests, expected: 2, got: %d", requests) - } - return - } -} - type mockLogger struct { sync.Mutex buf *bytes.Buffer diff --git a/pkg/search/query.go b/pkg/search/query.go index f93e1a46f..192750f35 100644 --- a/pkg/search/query.go +++ b/pkg/search/query.go @@ -11,6 +11,7 @@ import ( const ( KindRepositories = "repositories" KindIssues = "issues" + KindCommits = "commits" ) type Query struct { @@ -27,16 +28,24 @@ type Qualifiers struct { Archived *bool Assignee string Author string + AuthorDate string + AuthorEmail string + AuthorName string Base string Closed string Commenter string Comments string + Committer string + CommitterDate string + CommitterEmail string + CommitterName string Created string Draft *bool Followers string Fork string Forks string GoodFirstIssues string + Hash string Head string HelpWantedIssues string In []string @@ -47,9 +56,11 @@ type Qualifiers struct { Language string License []string Mentions string + Merge *bool Merged string Milestone string No []string + Parent string Project string Pushed string Reactions string @@ -65,6 +76,7 @@ type Qualifiers struct { TeamReviewRequested string Topic []string Topics string + Tree string Type string Updated string User string diff --git a/pkg/search/query_test.go b/pkg/search/query_test.go index 27f201508..c2b2d8605 100644 --- a/pkg/search/query_test.go +++ b/pkg/search/query_test.go @@ -20,6 +20,8 @@ func TestQueryString(t *testing.T) { Keywords: []string{"some", "keywords"}, Qualifiers: Qualifiers{ Archived: &trueBool, + AuthorEmail: "foo@example.com", + CommitterDate: "2021-02-28", Created: "created", Followers: "1", Fork: "true", @@ -38,7 +40,7 @@ func TestQueryString(t *testing.T) { Is: []string{"public"}, }, }, - out: "some keywords archived:true created:created followers:1 fork:true forks:2 good-first-issues:3 help-wanted-issues:4 in:description in:readme is:public language:language license:license pushed:updated size:5 stars:6 topic:topic topics:7 user:user", + out: "some keywords archived:true author-email:foo@example.com committer-date:2021-02-28 created:created followers:1 fork:true forks:2 good-first-issues:3 help-wanted-issues:4 in:description in:readme is:public language:language license:license pushed:updated size:5 stars:6 topic:topic topics:7 user:user", }, { name: "quotes keywords", @@ -74,6 +76,8 @@ func TestQualifiersMap(t *testing.T) { name: "changes qualifiers to map", qualifiers: Qualifiers{ Archived: &trueBool, + AuthorEmail: "foo@example.com", + CommitterDate: "2021-02-28", Created: "created", Followers: "1", Fork: "true", @@ -93,6 +97,8 @@ func TestQualifiersMap(t *testing.T) { }, out: map[string][]string{ "archived": {"true"}, + "author-email": {"foo@example.com"}, + "committer-date": {"2021-02-28"}, "created": {"created"}, "followers": {"1"}, "fork": {"true"}, diff --git a/pkg/search/result.go b/pkg/search/result.go index 8c2e3aa39..d7113bfda 100644 --- a/pkg/search/result.go +++ b/pkg/search/result.go @@ -6,6 +6,17 @@ import ( "time" ) +var CommitFields = []string{ + "author", + "commit", + "committer", + "sha", + "id", + "parents", + "repository", + "url", +} + var RepositoryFields = []string{ "createdAt", "defaultBranch", @@ -61,6 +72,12 @@ var PullRequestFields = append(IssueFields, "isDraft", ) +type CommitsResult struct { + IncompleteResults bool `json:"incomplete_results"` + Items []Commit `json:"items"` + Total int `json:"total_count"` +} + type RepositoriesResult struct { IncompleteResults bool `json:"incomplete_results"` Items []Repository `json:"items"` @@ -73,6 +90,40 @@ type IssuesResult struct { Total int `json:"total_count"` } +type Commit struct { + Author User `json:"author"` + Committer User `json:"committer"` + ID string `json:"node_id"` + Info CommitInfo `json:"commit"` + Parents []Parent `json:"parents"` + Repo Repository `json:"repository"` + Sha string `json:"sha"` + URL string `json:"html_url"` +} + +type CommitInfo struct { + Author CommitUser `json:"author"` + CommentCount int `json:"comment_count"` + Committer CommitUser `json:"committer"` + Message string `json:"message"` + Tree Tree `json:"tree"` +} + +type CommitUser struct { + Date time.Time `json:"date"` + Email string `json:"email"` + Name string `json:"name"` +} + +type Tree struct { + Sha string `json:"sha"` +} + +type Parent struct { + Sha string `json:"sha"` + URL string `json:"html_url"` +} + type Repository struct { CreatedAt time.Time `json:"created_at"` DefaultBranch string `json:"default_branch"` @@ -120,13 +171,6 @@ type User struct { URL string `json:"html_url"` } -func (u *User) IsBot() bool { - // copied from api/queries_issue.go - // would ideally be shared, but it would require coordinating a "user" - // abstraction in a bunch of places. - return u.ID == "" -} - type Issue struct { Assignees []User `json:"assignees"` Author User `json:"user"` @@ -157,18 +201,6 @@ type PullRequest struct { MergedAt time.Time `json:"merged_at"` } -// the state of an issue or a pull request, -// may be either open or closed. -// for a pull request, the "merged" state is -// inferred from a value for merged_at and -// which we take return instead of the "closed" state. -func (issue Issue) State() string { - if !issue.PullRequest.MergedAt.IsZero() { - return "merged" - } - return issue.StateInternal -} - type Label struct { Color string `json:"color"` Description string `json:"description"` @@ -176,6 +208,83 @@ type Label struct { Name string `json:"name"` } +func (u User) IsBot() bool { + // copied from api/queries_issue.go + // would ideally be shared, but it would require coordinating a "user" + // abstraction in a bunch of places. + return u.ID == "" +} + +func (u User) ExportData() map[string]interface{} { + isBot := u.IsBot() + login := u.Login + if isBot { + login = "app/" + login + } + return map[string]interface{}{ + "id": u.ID, + "login": login, + "type": u.Type, + "url": u.URL, + "is_bot": isBot, + } +} + +func (commit Commit) ExportData(fields []string) map[string]interface{} { + v := reflect.ValueOf(commit) + data := map[string]interface{}{} + for _, f := range fields { + switch f { + case "author": + data[f] = commit.Author.ExportData() + case "commit": + info := commit.Info + data[f] = map[string]interface{}{ + "author": map[string]interface{}{ + "date": info.Author.Date, + "email": info.Author.Email, + "name": info.Author.Name, + }, + "committer": map[string]interface{}{ + "date": info.Committer.Date, + "email": info.Committer.Email, + "name": info.Committer.Name, + }, + "comment_count": info.CommentCount, + "message": info.Message, + "tree": map[string]interface{}{"sha": info.Tree.Sha}, + } + case "committer": + data[f] = commit.Committer.ExportData() + case "parents": + parents := make([]interface{}, 0, len(commit.Parents)) + for _, parent := range commit.Parents { + parents = append(parents, map[string]interface{}{ + "sha": parent.Sha, + "url": parent.URL, + }) + } + data[f] = parents + case "repository": + repo := commit.Repo + data[f] = map[string]interface{}{ + "description": repo.Description, + "fullName": repo.FullName, + "name": repo.Name, + "id": repo.ID, + "isFork": repo.IsFork, + "isPrivate": repo.IsPrivate, + "owner": repo.Owner.ExportData(), + "url": repo.URL, + } + default: + sf := fieldByName(v, f) + data[f] = sf.Interface() + } + } + return data +} + func (repo Repository) ExportData(fields []string) map[string]interface{} { v := reflect.ValueOf(repo) data := map[string]interface{}{} @@ -188,12 +297,7 @@ func (repo Repository) ExportData(fields []string) map[string]interface{} { "url": repo.License.URL, } case "owner": - data[f] = map[string]interface{}{ - "id": repo.Owner.ID, - "login": repo.Owner.Login, - "type": repo.Owner.Type, - "url": repo.Owner.URL, - } + data[f] = repo.Owner.ExportData() default: sf := fieldByName(v, f) data[f] = sf.Interface() @@ -202,6 +306,16 @@ func (repo Repository) ExportData(fields []string) map[string]interface{} { return data } +// The state of an issue or a pull request, may be either open or closed. +// For a pull request, the "merged" state is inferred from a value for merged_at and +// which we take return instead of the "closed" state. +func (issue Issue) State() string { + if !issue.PullRequest.MergedAt.IsZero() { + return "merged" + } + return issue.StateInternal +} + func (issue Issue) IsPullRequest() bool { return issue.PullRequest.URL != "" } @@ -214,31 +328,11 @@ func (issue Issue) ExportData(fields []string) map[string]interface{} { case "assignees": assignees := make([]interface{}, 0, len(issue.Assignees)) for _, assignee := range issue.Assignees { - isBot := assignee.IsBot() - login := assignee.Login - if isBot { - login = "app/" + login - } - assignees = append(assignees, map[string]interface{}{ - "id": assignee.ID, - "login": login, - "type": assignee.Type, - "is_bot": isBot, - }) + assignees = append(assignees, assignee.ExportData()) } data[f] = assignees case "author": - isBot := issue.Author.IsBot() - login := issue.Author.Login - if isBot { - login = "app/" + login - } - data[f] = map[string]interface{}{ - "id": issue.Author.ID, - "login": login, - "type": issue.Author.Type, - "is_bot": isBot, - } + data[f] = issue.Author.ExportData() case "isPullRequest": data[f] = issue.IsPullRequest() case "labels": diff --git a/pkg/search/result_test.go b/pkg/search/result_test.go index 756e3b908..cdf424b0b 100644 --- a/pkg/search/result_test.go +++ b/pkg/search/result_test.go @@ -11,6 +11,42 @@ import ( "github.com/stretchr/testify/require" ) +func TestCommitExportData(t *testing.T) { + var authoredAt = time.Date(2021, 2, 27, 11, 30, 0, 0, time.UTC) + var committedAt = time.Date(2021, 2, 28, 12, 30, 0, 0, time.UTC) + tests := []struct { + name string + fields []string + commit Commit + output string + }{ + { + name: "exports requested fields", + fields: []string{"author", "commit", "committer", "sha"}, + commit: Commit{ + Author: User{Login: "foo"}, + Committer: User{Login: "bar", ID: "123"}, + Info: CommitInfo{ + Author: CommitUser{Date: authoredAt, Name: "Foo"}, + Committer: CommitUser{Date: committedAt, Name: "Bar"}, + Message: "test message", + }, + Sha: "8dd03144ffdc6c0d", + }, + output: `{"author":{"id":"","is_bot":true,"login":"app/foo","type":"","url":""},"commit":{"author":{"date":"2021-02-27T11:30:00Z","email":"","name":"Foo"},"comment_count":0,"committer":{"date":"2021-02-28T12:30:00Z","email":"","name":"Bar"},"message":"test message","tree":{"sha":""}},"committer":{"id":"123","is_bot":false,"login":"bar","type":"","url":""},"sha":"8dd03144ffdc6c0d"}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exported := tt.commit.ExportData(tt.fields) + buf := bytes.Buffer{} + enc := json.NewEncoder(&buf) + require.NoError(t, enc.Encode(exported)) + assert.Equal(t, tt.output, strings.TrimSpace(buf.String())) + }) + } +} + func TestRepositoryExportData(t *testing.T) { var createdAt = time.Date(2021, 2, 28, 12, 30, 0, 0, time.UTC) tests := []struct { @@ -67,7 +103,7 @@ func TestIssueExportData(t *testing.T) { Title: "title", UpdatedAt: updatedAt, }, - output: `{"assignees":[{"id":"123","is_bot":false,"login":"test","type":""},{"id":"","is_bot":true,"login":"app/foo","type":""}],"body":"body","commentsCount":1,"isLocked":true,"labels":[{"color":"","description":"","id":"","name":"label1"},{"color":"","description":"","id":"","name":"label2"}],"repository":{"name":"repo","nameWithOwner":"owner/repo"},"title":"title","updatedAt":"2021-02-28T12:30:00Z"}`, + output: `{"assignees":[{"id":"123","is_bot":false,"login":"test","type":"","url":""},{"id":"","is_bot":true,"login":"app/foo","type":"","url":""}],"body":"body","commentsCount":1,"isLocked":true,"labels":[{"color":"","description":"","id":"","name":"label1"},{"color":"","description":"","id":"","name":"label2"}],"repository":{"name":"repo","nameWithOwner":"owner/repo"},"title":"title","updatedAt":"2021-02-28T12:30:00Z"}`, }, { name: "state when issue", diff --git a/pkg/search/searcher.go b/pkg/search/searcher.go index 778402621..baa84c03c 100644 --- a/pkg/search/searcher.go +++ b/pkg/search/searcher.go @@ -25,6 +25,7 @@ var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) //go:generate moq -rm -out searcher_mock.go . Searcher type Searcher interface { + Commits(Query) (CommitsResult, error) Repositories(Query) (RepositoriesResult, error) Issues(Query) (IssuesResult, error) URL(Query) string @@ -56,6 +57,30 @@ func NewSearcher(client *http.Client, host string) Searcher { } } +func (s searcher) Commits(query Query) (CommitsResult, error) { + result := CommitsResult{} + toRetrieve := query.Limit + var resp *http.Response + var err error + for toRetrieve > 0 { + query.Limit = min(toRetrieve, maxPerPage) + query.Page = nextPage(resp) + if query.Page == 0 { + break + } + page := CommitsResult{} + resp, err = s.search(query, &page) + if err != nil { + return result, err + } + result.IncompleteResults = page.IncompleteResults + result.Total = page.Total + result.Items = append(result.Items, page.Items...) + toRetrieve = toRetrieve - len(page.Items) + } + return result, nil +} + func (s searcher) Repositories(query Query) (RepositoriesResult, error) { result := RepositoriesResult{} toRetrieve := query.Limit diff --git a/pkg/search/searcher_mock.go b/pkg/search/searcher_mock.go index 12c31350d..c1eecdaf4 100644 --- a/pkg/search/searcher_mock.go +++ b/pkg/search/searcher_mock.go @@ -17,6 +17,9 @@ var _ Searcher = &SearcherMock{} // // // make and configure a mocked Searcher // mockedSearcher := &SearcherMock{ +// CommitsFunc: func(query Query) (CommitsResult, error) { +// panic("mock out the Commits method") +// }, // IssuesFunc: func(query Query) (IssuesResult, error) { // panic("mock out the Issues method") // }, @@ -33,6 +36,9 @@ var _ Searcher = &SearcherMock{} // // } type SearcherMock struct { + // CommitsFunc mocks the Commits method. + CommitsFunc func(query Query) (CommitsResult, error) + // IssuesFunc mocks the Issues method. IssuesFunc func(query Query) (IssuesResult, error) @@ -44,6 +50,11 @@ type SearcherMock struct { // calls tracks calls to the methods. calls struct { + // Commits holds details about calls to the Commits method. + Commits []struct { + // Query is the query argument value. + Query Query + } // Issues holds details about calls to the Issues method. Issues []struct { // Query is the query argument value. @@ -60,11 +71,44 @@ type SearcherMock struct { Query Query } } + lockCommits sync.RWMutex lockIssues sync.RWMutex lockRepositories sync.RWMutex lockURL sync.RWMutex } +// Commits calls CommitsFunc. +func (mock *SearcherMock) Commits(query Query) (CommitsResult, error) { + if mock.CommitsFunc == nil { + panic("SearcherMock.CommitsFunc: method is nil but Searcher.Commits was just called") + } + callInfo := struct { + Query Query + }{ + Query: query, + } + mock.lockCommits.Lock() + mock.calls.Commits = append(mock.calls.Commits, callInfo) + mock.lockCommits.Unlock() + return mock.CommitsFunc(query) +} + +// CommitsCalls gets all the calls that were made to Commits. +// Check the length with: +// +// len(mockedSearcher.CommitsCalls()) +func (mock *SearcherMock) CommitsCalls() []struct { + Query Query +} { + var calls []struct { + Query Query + } + mock.lockCommits.RLock() + calls = mock.calls.Commits + mock.lockCommits.RUnlock() + return calls +} + // Issues calls IssuesFunc. func (mock *SearcherMock) Issues(query Query) (IssuesResult, error) { if mock.IssuesFunc == nil { diff --git a/pkg/search/searcher_test.go b/pkg/search/searcher_test.go index 99ea93479..8cc90c533 100644 --- a/pkg/search/searcher_test.go +++ b/pkg/search/searcher_test.go @@ -10,6 +10,163 @@ import ( "github.com/stretchr/testify/assert" ) +func TestSearcherCommits(t *testing.T) { + query := Query{ + Keywords: []string{"keyword"}, + Kind: "commits", + Limit: 30, + Order: "desc", + Sort: "committer-date", + Qualifiers: Qualifiers{ + Author: "foobar", + CommitterDate: ">2021-02-28", + }, + } + + values := url.Values{ + "page": []string{"1"}, + "per_page": []string{"30"}, + "order": []string{"desc"}, + "sort": []string{"committer-date"}, + "q": []string{"keyword author:foobar committer-date:>2021-02-28"}, + } + + tests := []struct { + name string + host string + query Query + result CommitsResult + wantErr bool + errMsg string + httpStubs func(*httpmock.Registry) + }{ + { + name: "searches commits", + query: query, + result: CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "abc"}}, + Total: 1, + }, + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.QueryMatcher("GET", "search/commits", values), + httpmock.JSONResponse(CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "abc"}}, + Total: 1, + }), + ) + }, + }, + { + name: "searches commits for enterprise host", + host: "enterprise.com", + query: query, + result: CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "abc"}}, + Total: 1, + }, + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.QueryMatcher("GET", "api/v3/search/commits", values), + httpmock.JSONResponse(CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "abc"}}, + Total: 1, + }), + ) + }, + }, + { + name: "paginates results", + query: query, + result: CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "abc"}, {Sha: "def"}}, + Total: 2, + }, + httpStubs: func(reg *httpmock.Registry) { + firstReq := httpmock.QueryMatcher("GET", "search/commits", values) + firstRes := httpmock.JSONResponse(CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "abc"}}, + Total: 2, + }, + ) + firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) + secondReq := httpmock.QueryMatcher("GET", "search/commits", url.Values{ + "page": []string{"2"}, + "per_page": []string{"29"}, + "order": []string{"desc"}, + "sort": []string{"committer-date"}, + "q": []string{"keyword author:foobar committer-date:>2021-02-28"}, + }, + ) + secondRes := httpmock.JSONResponse(CommitsResult{ + IncompleteResults: false, + Items: []Commit{{Sha: "def"}}, + Total: 2, + }, + ) + reg.Register(firstReq, firstRes) + reg.Register(secondReq, secondRes) + }, + }, + { + name: "handles search errors", + query: query, + wantErr: true, + errMsg: heredoc.Doc(` + Invalid search query "keyword author:foobar committer-date:>2021-02-28". + "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.QueryMatcher("GET", "search/commits", values), + httpmock.WithHeader( + httpmock.StatusStringResponse(422, + `{ + "message":"Validation Failed", + "errors":[ + { + "message":"\"blah\" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.", + "resource":"Search", + "field":"q", + "code":"invalid" + } + ], + "documentation_url":"https://docs.github.com/v3/search/" + }`, + ), "Content-Type", "application/json"), + ) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + if tt.httpStubs != nil { + tt.httpStubs(reg) + } + client := &http.Client{Transport: reg} + if tt.host == "" { + tt.host = "github.com" + } + searcher := NewSearcher(client, tt.host) + result, err := searcher.Commits(tt.query) + if tt.wantErr { + assert.EqualError(t, err, tt.errMsg) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.result, result) + }) + } +} + func TestSearcherRepositories(t *testing.T) { query := Query{ Keywords: []string{"keyword"},