Merge remote-tracking branch 'origin' into pr-commands-isolate

This commit is contained in:
Mislav Marohnić 2020-08-04 15:01:30 +02:00
commit a73584db72
32 changed files with 387 additions and 293 deletions

View file

@ -11,6 +11,7 @@ import (
"regexp"
"strings"
"github.com/cli/cli/internal/ghinstance"
"github.com/henvic/httpretty"
"github.com/shurcooL/graphql"
)
@ -43,25 +44,21 @@ func NewClientFromHTTP(httpClient *http.Client) *Client {
func AddHeader(name, value string) ClientOption {
return func(tr http.RoundTripper) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
// prevent the token from leaking to non-GitHub hosts
// TODO: GHE support
if !strings.EqualFold(name, "Authorization") || strings.HasSuffix(req.URL.Hostname(), ".github.com") {
req.Header.Add(name, value)
}
req.Header.Add(name, value)
return tr.RoundTrip(req)
}}
}
}
// AddHeaderFunc is an AddHeader that gets the string value from a function
func AddHeaderFunc(name string, value func() string) ClientOption {
func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption {
return func(tr http.RoundTripper) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
// prevent the token from leaking to non-GitHub hosts
// TODO: GHE support
if !strings.EqualFold(name, "Authorization") || strings.HasSuffix(req.URL.Hostname(), ".github.com") {
req.Header.Add(name, value())
value, err := getValue(req)
if err != nil {
return nil, err
}
req.Header.Add(name, value)
return tr.RoundTrip(req)
}}
}
@ -244,14 +241,13 @@ func (c Client) HasScopes(wantedScopes ...string) (bool, string, error) {
}
// GraphQL performs a GraphQL request and parses the response
func (c Client) GraphQL(query string, variables map[string]interface{}, data interface{}) error {
url := "https://api.github.com/graphql"
func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error {
reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables})
if err != nil {
return err
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody))
req, err := http.NewRequest("POST", ghinstance.GraphQLEndpoint(hostname), bytes.NewBuffer(reqBody))
if err != nil {
return err
}
@ -267,13 +263,13 @@ func (c Client) GraphQL(query string, variables map[string]interface{}, data int
return handleResponse(resp, data)
}
func graphQLClient(h *http.Client) *graphql.Client {
return graphql.NewClient("https://api.github.com/graphql", h)
func graphQLClient(h *http.Client, hostname string) *graphql.Client {
return graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), h)
}
// REST performs a REST request and parses the response.
func (c Client) REST(method string, p string, body io.Reader, data interface{}) error {
url := "https://api.github.com/" + p
func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error {
url := ghinstance.RESTPrefix(hostname) + p
req, err := http.NewRequest(method, url, body)
if err != nil {
return err

View file

@ -33,7 +33,7 @@ func TestGraphQL(t *testing.T) {
}{}
http.StubResponse(200, bytes.NewBufferString(`{"data":{"viewer":{"login":"hubot"}}}`))
err := client.GraphQL("QUERY", vars, &response)
err := client.GraphQL("github.com", "QUERY", vars, &response)
eq(t, err, nil)
eq(t, response.Viewer.Login, "hubot")
@ -55,7 +55,7 @@ func TestGraphQLError(t *testing.T) {
]
}`))
err := client.GraphQL("", nil, &response)
err := client.GraphQL("github.com", "", nil, &response)
if err == nil || err.Error() != "GraphQL error: OH NO\nthis is fine" {
t.Fatalf("got %q", err.Error())
}
@ -71,7 +71,7 @@ func TestRESTGetDelete(t *testing.T) {
http.StubResponse(204, bytes.NewBuffer([]byte{}))
r := bytes.NewReader([]byte(`{}`))
err := client.REST("DELETE", "applications/CLIENTID/grant", r, nil)
err := client.REST("github.com", "DELETE", "applications/CLIENTID/grant", r, nil)
eq(t, err, nil)
}
@ -82,7 +82,7 @@ func TestRESTError(t *testing.T) {
http.StubResponse(422, bytes.NewBufferString(`{"message": "OH NO"}`))
var httpErr HTTPError
err := client.REST("DELETE", "repos/branch", nil, nil)
err := client.REST("github.com", "DELETE", "repos/branch", nil, nil)
if err == nil || !errors.As(err, &httpErr) {
t.Fatalf("got %v", err)
}

View file

@ -1,52 +0,0 @@
package api
import (
"bytes"
"encoding/json"
)
// Gist represents a GitHub's gist.
type Gist struct {
Description string `json:"description,omitempty"`
Public bool `json:"public,omitempty"`
Files map[GistFilename]GistFile `json:"files,omitempty"`
HTMLURL string `json:"html_url,omitempty"`
}
type GistFilename string
type GistFile struct {
Content string `json:"content,omitempty"`
}
// Create a gist for authenticated user.
//
// GitHub API docs: https://developer.github.com/v3/gists/#create-a-gist
func GistCreate(client *Client, description string, public bool, files map[string]string) (*Gist, error) {
gistFiles := map[GistFilename]GistFile{}
for filename, content := range files {
gistFiles[GistFilename(filename)] = GistFile{content}
}
path := "gists"
body := &Gist{
Description: description,
Public: public,
Files: gistFiles,
}
result := Gist{}
requestByte, err := json.Marshal(body)
if err != nil {
return nil, err
}
requestBody := bytes.NewReader(requestByte)
err = client.REST("POST", path, requestBody, &result)
if err != nil {
return nil, err
}
return &result, nil
}

View file

@ -112,7 +112,7 @@ func IssueCreate(client *Client, repo *Repository, params map[string]interface{}
}
}{}
err := client.GraphQL(query, variables, &result)
err := client.GraphQL(repo.RepoHost(), query, variables, &result)
if err != nil {
return nil, err
}
@ -171,7 +171,7 @@ func IssueStatus(client *Client, repo ghrepo.Interface, currentUsername string)
}
var resp response
err := client.GraphQL(query, variables, &resp)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -270,7 +270,7 @@ func IssueList(client *Client, repo ghrepo.Interface, state string, labels []str
loop:
for {
variables["limit"] = pageLimit
err := client.GraphQL(query, variables, &response)
err := client.GraphQL(repo.RepoHost(), query, variables, &response)
if err != nil {
return nil, err
}
@ -361,7 +361,7 @@ func IssueByNumber(client *Client, repo ghrepo.Interface, number int) (*Issue, e
}
var resp response
err := client.GraphQL(query, variables, &resp)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -389,7 +389,7 @@ func IssueClose(client *Client, repo ghrepo.Interface, issue Issue) error {
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
err := gql.MutateNamed(context.Background(), "IssueClose", &mutation, variables)
if err != nil {
@ -414,7 +414,7 @@ func IssueReopen(client *Client, repo ghrepo.Interface, issue Issue) error {
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
err := gql.MutateNamed(context.Background(), "IssueReopen", &mutation, variables)
return err

View file

@ -3,11 +3,12 @@ package api
import (
"context"
"github.com/cli/cli/internal/ghrepo"
"github.com/shurcooL/githubv4"
)
// OrganizationProjects fetches all open projects for an organization
func OrganizationProjects(client *Client, owner string) ([]RepoProject, error) {
func OrganizationProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) {
var query struct {
Organization struct {
Projects struct {
@ -21,11 +22,11 @@ func OrganizationProjects(client *Client, owner string) ([]RepoProject, error) {
}
variables := map[string]interface{}{
"owner": githubv4.String(owner),
"owner": githubv4.String(repo.RepoOwner()),
"endCursor": (*githubv4.String)(nil),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
var projects []RepoProject
for {
@ -50,7 +51,7 @@ type OrgTeam struct {
}
// OrganizationTeams fetches all the teams in an organization
func OrganizationTeams(client *Client, owner string) ([]OrgTeam, error) {
func OrganizationTeams(client *Client, repo ghrepo.Interface) ([]OrgTeam, error) {
var query struct {
Organization struct {
Teams struct {
@ -64,11 +65,11 @@ func OrganizationTeams(client *Client, owner string) ([]OrgTeam, error) {
}
variables := map[string]interface{}{
"owner": githubv4.String(owner),
"owner": githubv4.String(repo.RepoOwner()),
"endCursor": (*githubv4.String)(nil),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
var teams []OrgTeam
for {

View file

@ -355,7 +355,7 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
}
var resp response
err := client.GraphQL(query, variables, &resp)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -492,7 +492,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu
}
var resp response
err := client.GraphQL(query, variables, &resp)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -605,7 +605,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea
}
var resp response
err := client.GraphQL(query, variables, &resp)
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}
@ -655,7 +655,7 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter
}
}{}
err := client.GraphQL(query, variables, &result)
err := client.GraphQL(repo.RepoHost(), query, variables, &result)
if err != nil {
return nil, err
}
@ -681,7 +681,7 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter
variables := map[string]interface{}{
"input": updateParams,
}
err := client.GraphQL(updateQuery, variables, &result)
err := client.GraphQL(repo.RepoHost(), updateQuery, variables, &result)
if err != nil {
return nil, err
}
@ -706,7 +706,7 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter
variables := map[string]interface{}{
"input": reviewParams,
}
err := client.GraphQL(reviewQuery, variables, &result)
err := client.GraphQL(repo.RepoHost(), reviewQuery, variables, &result)
if err != nil {
return nil, err
}
@ -726,7 +726,7 @@ func isBlank(v interface{}) bool {
}
}
func AddReview(client *Client, pr *PullRequest, input *PullRequestReviewInput) error {
func AddReview(client *Client, repo ghrepo.Interface, pr *PullRequest, input *PullRequestReviewInput) error {
var mutation struct {
AddPullRequestReview struct {
ClientMutationID string
@ -750,11 +750,11 @@ func AddReview(client *Client, pr *PullRequest, input *PullRequestReviewInput) e
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
return gql.MutateNamed(context.Background(), "PullRequestReviewAdd", &mutation, variables)
}
func PullRequestList(client *Client, vars map[string]interface{}, limit int) (*PullRequestAndTotalCount, error) {
func PullRequestList(client *Client, repo ghrepo.Interface, vars map[string]interface{}, limit int) (*PullRequestAndTotalCount, error) {
type prBlock struct {
Edges []struct {
Node PullRequest
@ -851,10 +851,8 @@ func PullRequestList(client *Client, vars map[string]interface{}, limit int) (*P
}
}
}`
owner := vars["owner"].(string)
repo := vars["repo"].(string)
search := []string{
fmt.Sprintf("repo:%s/%s", owner, repo),
fmt.Sprintf("repo:%s/%s", repo.RepoOwner(), repo.RepoName()),
fmt.Sprintf("assignee:%s", assignee),
"is:pr",
"sort:created-desc",
@ -880,6 +878,8 @@ func PullRequestList(client *Client, vars map[string]interface{}, limit int) (*P
}
variables["q"] = strings.Join(search, " ")
} else {
variables["owner"] = repo.RepoOwner()
variables["repo"] = repo.RepoName()
for name, val := range vars {
variables[name] = val
}
@ -888,7 +888,7 @@ loop:
for {
variables["limit"] = pageLimit
var data response
err := client.GraphQL(query, variables, &data)
err := client.GraphQL(repo.RepoHost(), query, variables, &data)
if err != nil {
return nil, err
}
@ -937,7 +937,7 @@ func PullRequestClose(client *Client, repo ghrepo.Interface, pr *PullRequest) er
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
err := gql.MutateNamed(context.Background(), "PullRequestClose", &mutation, variables)
return err
@ -958,7 +958,7 @@ func PullRequestReopen(client *Client, repo ghrepo.Interface, pr *PullRequest) e
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
err := gql.MutateNamed(context.Background(), "PullRequestReopen", &mutation, variables)
return err
@ -988,7 +988,7 @@ func PullRequestMerge(client *Client, repo ghrepo.Interface, pr *PullRequest, m
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
err := gql.MutateNamed(context.Background(), "PullRequestMerge", &mutation, variables)
return err
@ -1009,13 +1009,13 @@ func PullRequestReady(client *Client, repo ghrepo.Interface, pr *PullRequest) er
},
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
return gql.MutateNamed(context.Background(), "PullRequestReadyForReview", &mutation, variables)
}
func BranchDeleteRemote(client *Client, repo ghrepo.Interface, branch string) error {
path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch)
return client.REST("DELETE", path, nil, nil)
return client.REST(repo.RepoHost(), "DELETE", path, nil, nil)
}
func min(a, b int) int {

View file

@ -106,7 +106,7 @@ func GitHubRepo(client *Client, repo ghrepo.Interface) (*Repository, error) {
result := struct {
Repository Repository
}{}
err := client.GraphQL(query, variables, &result)
err := client.GraphQL(repo.RepoHost(), query, variables, &result)
if err != nil {
return nil, err
@ -145,7 +145,7 @@ func RepoParent(client *Client, repo ghrepo.Interface) (ghrepo.Interface, error)
"name": githubv4.String(repo.RepoName()),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
err := gql.QueryNamed(context.Background(), "RepositoryFindParent", &query, variables)
if err != nil {
return nil, err
@ -189,7 +189,7 @@ func RepoNetwork(client *Client, repos []ghrepo.Interface) (RepoNetworkResult, e
graphqlResult := make(map[string]*json.RawMessage)
var result RepoNetworkResult
err := client.GraphQL(fmt.Sprintf(`
err := client.GraphQL(hostname, fmt.Sprintf(`
fragment repo on Repository {
id
name
@ -285,7 +285,7 @@ func ForkRepo(client *Client, repo ghrepo.Interface) (*Repository, error) {
path := fmt.Sprintf("repos/%s/forks", ghrepo.FullName(repo))
body := bytes.NewBufferString(`{}`)
result := repositoryV3{}
err := client.REST("POST", path, body, &result)
err := client.REST(repo.RepoHost(), "POST", path, body, &result)
if err != nil {
return nil, err
}
@ -318,7 +318,7 @@ func RepoFindFork(client *Client, repo ghrepo.Interface) (*Repository, error) {
"repo": repo.RepoName(),
}
if err := client.GraphQL(`
if err := client.GraphQL(repo.RepoHost(), `
query RepositoryFindFork($owner: String!, $repo: String!) {
repository(owner: $owner, name: $repo) {
forks(first: 1, affiliations: [OWNER, COLLABORATOR]) {
@ -464,7 +464,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput
if input.Reviewers {
count++
go func() {
teams, err := OrganizationTeams(client, repo.RepoOwner())
teams, err := OrganizationTeams(client, repo)
// TODO: better detection of non-org repos
if err != nil && !strings.HasPrefix(err.Error(), "Could not resolve to an Organization") {
errc <- fmt.Errorf("error fetching organization teams: %w", err)
@ -495,7 +495,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput
}
result.Projects = projects
orgProjects, err := OrganizationProjects(client, repo.RepoOwner())
orgProjects, err := OrganizationProjects(client, repo)
// TODO: better detection of non-org repos
if err != nil && !strings.HasPrefix(err.Error(), "Could not resolve to an Organization") {
errc <- fmt.Errorf("error fetching organization projects: %w", err)
@ -591,7 +591,7 @@ func RepoResolveMetadataIDs(client *Client, repo ghrepo.Interface, input RepoRes
fmt.Fprint(query, "}\n")
response := make(map[string]json.RawMessage)
err = client.GraphQL(query.String(), nil, &response)
err = client.GraphQL(repo.RepoHost(), query.String(), nil, &response)
if err != nil {
return result, err
}
@ -654,7 +654,7 @@ func RepoProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error)
"endCursor": (*githubv4.String)(nil),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
var projects []RepoProject
for {
@ -698,7 +698,7 @@ func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]RepoAssignee,
"endCursor": (*githubv4.String)(nil),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
var users []RepoAssignee
for {
@ -742,7 +742,7 @@ func RepoLabels(client *Client, repo ghrepo.Interface) ([]RepoLabel, error) {
"endCursor": (*githubv4.String)(nil),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
var labels []RepoLabel
for {
@ -786,7 +786,7 @@ func RepoMilestones(client *Client, repo ghrepo.Interface) ([]RepoMilestone, err
"endCursor": (*githubv4.String)(nil),
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, repo.RepoHost())
var milestones []RepoMilestone
for {

View file

@ -4,13 +4,13 @@ import (
"context"
)
func CurrentLoginName(client *Client) (string, error) {
func CurrentLoginName(client *Client, hostname string) (string, error) {
var query struct {
Viewer struct {
Login string
}
}
gql := graphQLClient(client.http)
gql := graphQLClient(client.http, hostname)
err := gql.QueryNamed(context.Background(), "UserCurrent", &query, nil)
return query.Viewer.Login, err
}

View file

@ -13,6 +13,7 @@ import (
"os"
"strings"
"github.com/cli/cli/internal/ghinstance"
"github.com/cli/cli/pkg/browser"
)
@ -52,9 +53,18 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
scopes = strings.Join(oa.Scopes, " ")
}
localhost := "127.0.0.1"
callbackPath := "/callback"
if ghinstance.IsEnterprise(oa.Hostname) {
// the OAuth app on Enterprise hosts is still registered with a legacy callback URL
// see https://github.com/cli/cli/pull/222, https://github.com/cli/cli/pull/650
localhost = "localhost"
callbackPath = "/"
}
q := url.Values{}
q.Set("client_id", oa.ClientID)
q.Set("redirect_uri", fmt.Sprintf("http://127.0.0.1:%d/callback", port))
q.Set("redirect_uri", fmt.Sprintf("http://%s:%d%s", localhost, port, callbackPath))
q.Set("scope", scopes)
q.Set("state", state)
@ -73,7 +83,7 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
_ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
oa.logf("server handler: %s\n", r.URL.Path)
if r.URL.Path != "/callback" {
if r.URL.Path != callbackPath {
w.WriteHeader(404)
return
}

View file

@ -281,7 +281,7 @@ func issueStatus(cmd *cobra.Command, args []string) error {
return err
}
currentUser, err := api.CurrentLoginName(apiClient)
currentUser, err := api.CurrentLoginName(apiClient, baseRepo.RepoHost())
if err != nil {
return err
}
@ -376,7 +376,7 @@ func listHeader(repoName string, itemName string, matchCount int, totalMatchCoun
return fmt.Sprintf("Showing %d of %s in %s that %s your search", matchCount, utils.Pluralize(totalMatchCount, itemName), repoName, matchVerb)
}
return fmt.Sprintf("Showing %d of %s in %s", matchCount, utils.Pluralize(totalMatchCount, itemName), repoName)
return fmt.Sprintf("Showing %d of %s in %s", matchCount, utils.Pluralize(totalMatchCount, fmt.Sprintf("open %s", itemName)), repoName)
}
func printRawIssuePreview(out io.Writer, issue *api.Issue) error {

View file

@ -145,7 +145,7 @@ func TestIssueList_tty(t *testing.T) {
}
eq(t, output.Stderr(), `
Showing 3 of 3 issues in OWNER/REPO
Showing 3 of 3 open issues in OWNER/REPO
`)
@ -847,7 +847,7 @@ func Test_listHeader(t *testing.T) {
totalMatchCount: 23,
hasFilters: false,
},
want: "Showing 1 of 23 genies in REPO",
want: "Showing 1 of 23 open genies in REPO",
},
{
name: "one result after filters",
@ -869,7 +869,7 @@ func Test_listHeader(t *testing.T) {
totalMatchCount: 1,
hasFilters: false,
},
want: "Showing 1 of 1 chip in REPO",
want: "Showing 1 of 1 open chip in REPO",
},
{
name: "one result in total after filters",
@ -891,7 +891,7 @@ func Test_listHeader(t *testing.T) {
totalMatchCount: 23,
hasFilters: false,
},
want: "Showing 4 of 23 plants in REPO",
want: "Showing 4 of 23 open plants in REPO",
},
{
name: "multiple results after filters",

View file

@ -265,8 +265,6 @@ func prList(cmd *cobra.Command, args []string) error {
}
params := map[string]interface{}{
"owner": baseRepo.RepoOwner(),
"repo": baseRepo.RepoName(),
"state": graphqlState,
}
if len(labels) > 0 {
@ -279,7 +277,7 @@ func prList(cmd *cobra.Command, args []string) error {
params["assignee"] = assignee
}
listResult, err := api.PullRequestList(apiClient, params, limit)
listResult, err := api.PullRequestList(apiClient, baseRepo, params, limit)
if err != nil {
return err
}

View file

@ -307,7 +307,7 @@ func TestPRList(t *testing.T) {
}
assert.Equal(t, `
Showing 3 of 3 pull requests in OWNER/REPO
Showing 3 of 3 open pull requests in OWNER/REPO
`, output.Stderr())

View file

@ -18,6 +18,7 @@ import (
"github.com/cli/cli/context"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghinstance"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/internal/run"
apiCmd "github.com/cli/cli/pkg/cmd/api"
@ -40,9 +41,6 @@ import (
"github.com/spf13/pflag"
)
// TODO these are sprinkled across command, context, config, and ghrepo
const defaultHostname = "github.com"
// Version is dynamically set by the toolchain or overridden by the Makefile.
var Version = "DEV"
@ -51,6 +49,8 @@ var BuildDate = "" // YYYY-MM-DD
var versionOutput = ""
var defaultStreams *iostreams.IOStreams
func init() {
if Version == "DEV" {
if info, ok := debug.ReadBuildInfo(); ok && info.Main.Version != "(devel)" {
@ -82,22 +82,21 @@ func init() {
return &cmdutil.FlagError{Err: err}
})
defaultStreams = iostreams.System()
// TODO: iron out how a factory incorporates context
cmdFactory := &cmdutil.Factory{
IOStreams: iostreams.System(),
IOStreams: defaultStreams,
HttpClient: func() (*http.Client, error) {
token := os.Getenv("GITHUB_TOKEN")
if len(token) == 0 {
// TODO: decouple from `context`
ctx := context.New()
var err error
// TODO: pass IOStreams to this so that the auth flow knows if it's interactive or not
token, err = ctx.AuthToken()
if err != nil {
return nil, err
}
// TODO: decouple from `context`
ctx := context.New()
cfg, err := ctx.Config()
if err != nil {
return nil, err
}
return httpClient(token), nil
// TODO: avoid setting Accept header for `api` command
return httpClient(defaultStreams, cfg, true), nil
},
BaseRepo: func() (ghrepo.Interface, error) {
// TODO: decouple from `context`
@ -233,14 +232,11 @@ var initContext = func() context.Context {
if repo := os.Getenv("GH_REPO"); repo != "" {
ctx.SetBaseRepo(repo)
}
if token := os.Getenv("GITHUB_TOKEN"); token != "" {
ctx.SetAuthToken(token)
}
return ctx
}
// BasicClient returns an API client that borrows from but does not depend on
// user configuration
// BasicClient returns an API client for github.com only that borrows from but
// does not depend on user configuration
func BasicClient() (*api.Client, error) {
var opts []api.ClientOption
if verbose := os.Getenv("DEBUG"); verbose != "" {
@ -251,7 +247,7 @@ func BasicClient() (*api.Client, error) {
token := os.Getenv("GITHUB_TOKEN")
if token == "" {
if c, err := config.ParseDefaultConfig(); err == nil {
token, _ = c.Get(defaultHostname, "oauth_token")
token, _ = c.Get(ghinstance.Default(), "oauth_token")
}
}
if token != "" {
@ -268,72 +264,68 @@ func contextForCommand(cmd *cobra.Command) context.Context {
return ctx
}
// for cmdutil-powered commands
func httpClient(token string) *http.Client {
// generic authenticated HTTP client for commands
func httpClient(io *iostreams.IOStreams, cfg config.Config, setAccept bool) *http.Client {
var opts []api.ClientOption
if verbose := os.Getenv("DEBUG"); verbose != "" {
opts = append(opts, apiVerboseLog())
}
opts = append(opts,
api.AddHeader("Authorization", fmt.Sprintf("token %s", token)),
api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", Version)),
api.AddHeaderFunc("Authorization", func(req *http.Request) (string, error) {
if token := os.Getenv("GITHUB_TOKEN"); token != "" {
return fmt.Sprintf("token %s", token), nil
}
hostname := ghinstance.NormalizeHostname(req.URL.Hostname())
token, err := cfg.Get(hostname, "oauth_token")
if token == "" {
var notFound *config.NotFoundError
// TODO: check if stdout is TTY too
if errors.As(err, &notFound) && io.IsStdinTTY() {
// interactive OAuth flow
token, err = config.AuthFlowWithConfig(cfg, hostname, "Notice: authentication required")
}
if err != nil {
return "", err
}
if token == "" {
// TODO: instruct user how to manually authenticate
return "", fmt.Errorf("authentication required for %s", hostname)
}
}
return fmt.Sprintf("token %s", token), nil
}),
)
if setAccept {
opts = append(opts,
api.AddHeaderFunc("Accept", func(req *http.Request) (string, error) {
// antiope-preview: Checks
accept := "application/vnd.github.antiope-preview+json"
if ghinstance.IsEnterprise(req.URL.Hostname()) {
// shadow-cat-preview: Draft pull requests
accept += ", application/vnd.github.shadow-cat-preview"
}
return accept, nil
}),
)
}
return api.NewHTTPClient(opts...)
}
// overridden in tests
// LEGACY; overridden in tests
var apiClientForContext = func(ctx context.Context) (*api.Client, error) {
token, err := ctx.AuthToken()
cfg, err := ctx.Config()
if err != nil {
return nil, err
}
var opts []api.ClientOption
if verbose := os.Getenv("DEBUG"); verbose != "" {
opts = append(opts, apiVerboseLog())
}
getAuthValue := func() string {
return fmt.Sprintf("token %s", token)
}
tokenFromEnv := func() bool {
return os.Getenv("GITHUB_TOKEN") == token
}
checkScopesFunc := func(appID string) error {
if config.IsGitHubApp(appID) && !tokenFromEnv() && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) {
cfg, err := ctx.Config()
if err != nil {
return err
}
newToken, err := config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required")
if err != nil {
return err
}
// update configuration in memory
token = newToken
} else {
fmt.Fprintln(os.Stderr, "Warning: gh now requires the `read:org` OAuth scope.")
fmt.Fprintln(os.Stderr, "Visit https://github.com/settings/tokens and edit your token to enable `read:org`")
if tokenFromEnv() {
fmt.Fprintln(os.Stderr, "or generate a new token for the GITHUB_TOKEN environment variable")
} else {
fmt.Fprintln(os.Stderr, "or generate a new token and paste it via `gh config set -h github.com oauth_token MYTOKEN`")
}
}
return nil
}
opts = append(opts,
api.CheckScopes("read:org", checkScopesFunc),
api.AddHeaderFunc("Authorization", getAuthValue),
api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", Version)),
// antiope-preview: Checks
api.AddHeader("Accept", "application/vnd.github.antiope-preview+json"),
)
return api.NewClient(opts...), nil
http := httpClient(defaultStreams, cfg, true)
return api.NewClientFromHTTP(http), nil
}
func apiVerboseLog() api.ClientOption {
@ -425,7 +417,8 @@ func determineEditor(cmd *cobra.Command) (string, error) {
if err != nil {
return "", fmt.Errorf("could not read config: %w", err)
}
editorCommand, _ = cfg.Get(defaultHostname, "editor")
// TODO: consider supporting setting an editor per GHE host
editorCommand, _ = cfg.Get(ghinstance.Default(), "editor")
}
return editorCommand, nil

View file

@ -16,10 +16,9 @@ func NewBlank() *blankContext {
// A Context implementation that queries the filesystem
type blankContext struct {
authToken string
branch string
baseRepo ghrepo.Interface
remotes Remotes
branch string
baseRepo ghrepo.Interface
remotes Remotes
}
func (c *blankContext) Config() (config.Config, error) {
@ -30,14 +29,6 @@ func (c *blankContext) Config() (config.Config, error) {
return cfg, nil
}
func (c *blankContext) AuthToken() (string, error) {
return c.authToken, nil
}
func (c *blankContext) SetAuthToken(t string) {
c.authToken = t
}
func (c *blankContext) Branch() (string, error) {
if c.branch == "" {
return "", fmt.Errorf("branch was not initialized: %w", git.ErrNotOnAnyBranch)

View file

@ -13,13 +13,8 @@ import (
"github.com/cli/cli/internal/ghrepo"
)
// TODO these are sprinkled across command, context, config, and ghrepo
const defaultHostname = "github.com"
// Context represents the interface for querying information about the current environment
type Context interface {
AuthToken() (string, error)
SetAuthToken(string)
Branch() (string, error)
SetBranch(string)
Remotes() (Remotes, error)
@ -164,11 +159,10 @@ func New() Context {
// A Context implementation that queries the filesystem
type fsContext struct {
config config.Config
remotes Remotes
branch string
baseRepo ghrepo.Interface
authToken string
config config.Config
remotes Remotes
branch string
baseRepo ghrepo.Interface
}
func (c *fsContext) Config() (config.Config, error) {
@ -180,37 +174,10 @@ func (c *fsContext) Config() (config.Config, error) {
return nil, err
}
c.config = cfg
c.authToken = ""
}
return c.config, nil
}
func (c *fsContext) AuthToken() (string, error) {
if c.authToken != "" {
return c.authToken, nil
}
cfg, err := c.Config()
if err != nil {
return "", err
}
var notFound *config.NotFoundError
token, err := cfg.Get(defaultHostname, "oauth_token")
if token == "" || errors.As(err, &notFound) {
// interactive OAuth flow
return config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: authentication required")
} else if err != nil {
return "", err
}
return token, nil
}
func (c *fsContext) SetAuthToken(t string) {
c.authToken = t
}
func (c *fsContext) Branch() (string, error) {
if c.branch != "" {
return c.branch, nil
@ -242,11 +209,16 @@ func (c *fsContext) Remotes() (Remotes, error) {
sshTranslate := git.ParseSSHConfig().Translator()
resolvedRemotes := translateRemotes(gitRemotes, sshTranslate)
// ignore non-github.com remotes
// TODO: GHE compatibility
// determine hostname by looking at the "main" remote
var hostname string
if mainRemote, err := resolvedRemotes.FindByName("upstream", "github", "origin", "*"); err == nil {
hostname = mainRemote.RepoHost()
}
// filter the rest of the remotes to just that hostname
filteredRemotes := Remotes{}
for _, r := range resolvedRemotes {
if r.RepoHost() != defaultHostname {
if r.RepoHost() != hostname {
continue
}
filteredRemotes = append(filteredRemotes, r)
@ -255,7 +227,6 @@ func (c *fsContext) Remotes() (Remotes, error) {
}
if len(c.remotes) == 0 {
// TODO: GHE compatibility
return nil, errors.New("no git remote found for a github.com repository")
}
return c.remotes, nil

View file

@ -74,7 +74,7 @@ func authFlow(oauthHost, notice string) (string, string, error) {
return "", "", err
}
userLogin, err := getViewer(token)
userLogin, err := getViewer(oauthHost, token)
if err != nil {
return "", "", err
}
@ -87,9 +87,9 @@ func AuthFlowComplete() {
_ = waitForEnter(os.Stdin)
}
func getViewer(token string) (string, error) {
func getViewer(hostname, token string) (string, error) {
http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token)))
return api.CurrentLoginName(http)
return api.CurrentLoginName(http, hostname)
}
func waitForEnter(r io.Reader) error {

View file

@ -0,0 +1,41 @@
package ghinstance
import (
"fmt"
"strings"
)
const defaultHostname = "github.com"
// Default returns the host name of the default GitHub instance
func Default() string {
return defaultHostname
}
// IsEnterprise reports whether a non-normalized host name looks like a GHE instance
func IsEnterprise(h string) bool {
return NormalizeHostname(h) != defaultHostname
}
// NormalizeHostname returns the canonical host name of a GitHub instance
func NormalizeHostname(h string) string {
hostname := strings.ToLower(h)
if strings.HasSuffix(hostname, "."+defaultHostname) {
return defaultHostname
}
return hostname
}
func GraphQLEndpoint(hostname string) string {
if IsEnterprise(hostname) {
return fmt.Sprintf("https://%s/api/graphql", hostname)
}
return "https://api.github.com/graphql"
}
func RESTPrefix(hostname string) string {
if IsEnterprise(hostname) {
return fmt.Sprintf("https://%s/api/v3/", hostname)
}
return "https://api.github.com/"
}

View file

@ -0,0 +1,121 @@
package ghinstance
import (
"testing"
)
func TestIsEnterprise(t *testing.T) {
tests := []struct {
host string
want bool
}{
{
host: "github.com",
want: false,
},
{
host: "api.github.com",
want: false,
},
{
host: "ghe.io",
want: true,
},
{
host: "example.com",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := IsEnterprise(tt.host); got != tt.want {
t.Errorf("IsEnterprise() = %v, want %v", got, tt.want)
}
})
}
}
func TestNormalizeHostname(t *testing.T) {
tests := []struct {
host string
want string
}{
{
host: "GitHub.com",
want: "github.com",
},
{
host: "api.github.com",
want: "github.com",
},
{
host: "ssh.github.com",
want: "github.com",
},
{
host: "upload.github.com",
want: "github.com",
},
{
host: "GHE.IO",
want: "ghe.io",
},
{
host: "git.my.org",
want: "git.my.org",
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := NormalizeHostname(tt.host); got != tt.want {
t.Errorf("NormalizeHostname() = %v, want %v", got, tt.want)
}
})
}
}
func TestGraphQLEndpoint(t *testing.T) {
tests := []struct {
host string
want string
}{
{
host: "github.com",
want: "https://api.github.com/graphql",
},
{
host: "ghe.io",
want: "https://ghe.io/api/graphql",
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := GraphQLEndpoint(tt.host); got != tt.want {
t.Errorf("GraphQLEndpoint() = %v, want %v", got, tt.want)
}
})
}
}
func TestRESTPrefix(t *testing.T) {
tests := []struct {
host string
want string
}{
{
host: "github.com",
want: "https://api.github.com/",
},
{
host: "ghe.io",
want: "https://ghe.io/api/v3/",
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := RESTPrefix(tt.host); got != tt.want {
t.Errorf("RESTPrefix() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -20,20 +20,21 @@ func Command(url string) (*exec.Cmd, error) {
// ForOS produces an exec.Cmd to open the web browser for different OS
func ForOS(goos, url string) *exec.Cmd {
exe := "open"
var args []string
switch goos {
case "darwin":
args = []string{"open"}
args = append(args, url)
case "windows":
args = []string{"cmd", "/c", "start"}
exe = "cmd"
r := strings.NewReplacer("&", "^&")
url = r.Replace(url)
args = append(args, "/c", "start", r.Replace(url))
default:
args = []string{"xdg-open"}
exe = "xdg-open"
args = append(args, url)
}
args = append(args, url)
cmd := exec.Command(args[0], args[1:]...)
cmd := exec.Command(exe, args...)
cmd.Stderr = os.Stderr
return cmd
}

View file

@ -74,6 +74,9 @@ on the format of the value:
- if the value starts with "@", the rest of the value is interpreted as a
filename to read the value from. Pass "-" to read from standard input.
For GraphQL requests, all fields other than "query" and "operationName" are
interpreted as GraphQL variables.
Raw request body may be passed from the outside via a file specified by '--input'.
Pass "-" to read from standard input. In this mode, parameters specified via
'--field' flags are serialized into URL query parameters.

View file

@ -13,10 +13,10 @@ import (
func httpRequest(client *http.Client, method string, p string, params interface{}, headers []string) (*http.Response, error) {
var requestURL string
// TODO: GHE support
if strings.Contains(p, "://") {
requestURL = p
} else {
// TODO: GHE support
requestURL = "https://api.github.com/" + p
}
@ -87,7 +87,7 @@ func groupGraphQLVariables(params map[string]interface{}) map[string]interface{}
for key, val := range params {
switch key {
case "query":
case "query", "operationName":
topLevel[key] = val
default:
variables[key] = val

View file

@ -55,6 +55,21 @@ func Test_groupGraphQLVariables(t *testing.T) {
},
},
},
{
name: "query + operationName + variables",
args: map[string]interface{}{
"query": "query Q1{} query Q2{}",
"operationName": "Q1",
"power": 9001,
},
want: map[string]interface{}{
"query": "query Q1{} query Q2{}",
"operationName": "Q1",
"variables": map[string]interface{}{
"power": 9001,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -11,6 +11,7 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/ghinstance"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
"github.com/cli/cli/utils"
@ -103,7 +104,8 @@ func createRun(opts *CreateOptions) error {
return err
}
gist, err := apiCreate(httpClient, opts.Description, opts.Public, files)
// TODO: GHE support
gist, err := apiCreate(httpClient, ghinstance.Default(), opts.Description, opts.Public, files)
if err != nil {
var httpError api.HTTPError
if errors.As(err, &httpError) {

View file

@ -22,7 +22,7 @@ type GistFile struct {
Content string `json:"content,omitempty"`
}
func apiCreate(httpClient *http.Client, description string, public bool, files map[string]string) (*Gist, error) {
func apiCreate(httpClient *http.Client, hostname string, description string, public bool, files map[string]string) (*Gist, error) {
gistFiles := map[GistFilename]GistFile{}
for filename, content := range files {
@ -44,7 +44,7 @@ func apiCreate(httpClient *http.Client, description string, public bool, files m
requestBody := bytes.NewReader(requestByte)
apiClient := api.NewClientFromHTTP(httpClient)
err = apiClient.REST("POST", path, requestBody, &result)
err = apiClient.REST(hostname, "POST", path, requestBody, &result)
if err != nil {
return nil, err
}

View file

@ -99,7 +99,7 @@ func reviewRun(opts *ReviewOptions) error {
return fmt.Errorf("did not understand desired review action: %w", err)
}
pr, _, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
pr, baseRepo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg)
if err != nil {
return err
}
@ -119,7 +119,7 @@ func reviewRun(opts *ReviewOptions) error {
}
}
err = api.AddReview(apiClient, pr, reviewData)
err = api.AddReview(apiClient, baseRepo, pr, reviewData)
if err != nil {
return fmt.Errorf("failed to create review: %w", err)
}

View file

@ -8,6 +8,7 @@ import (
"github.com/cli/cli/api"
"github.com/cli/cli/git"
"github.com/cli/cli/internal/config"
"github.com/cli/cli/internal/ghinstance"
"github.com/cli/cli/internal/ghrepo"
"github.com/cli/cli/pkg/cmdutil"
"github.com/cli/cli/pkg/iostreams"
@ -79,7 +80,8 @@ func cloneRun(opts *CloneOptions) error {
cloneURL := opts.Repository
if !strings.Contains(cloneURL, ":") {
if !strings.Contains(cloneURL, "/") {
currentUser, err := api.CurrentLoginName(apiClient)
// TODO: GHE compat
currentUser, err := api.CurrentLoginName(apiClient, ghinstance.Default())
if err != nil {
return err
}

View file

@ -5,6 +5,7 @@ import (
"net/http"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/ghinstance"
)
// repoCreateInput represents input parameters for repoCreate
@ -50,7 +51,10 @@ func repoCreate(client *http.Client, input repoCreateInput) (*api.Repository, er
"input": input,
}
err := apiClient.GraphQL(`
// TODO: GHE support
hostname := ghinstance.Default()
err := apiClient.GraphQL(hostname, `
mutation RepositoryCreate($input: CreateRepositoryInput!) {
createRepository(input: $input) {
repository {
@ -66,8 +70,7 @@ func repoCreate(client *http.Client, input repoCreateInput) (*api.Repository, er
return nil, err
}
// FIXME: support Enterprise hosts
return api.InitRepoHostname(&response.CreateRepository.Repository, "github.com"), nil
return api.InitRepoHostname(&response.CreateRepository.Repository, hostname), nil
}
// using API v3 here because the equivalent in GraphQL needs `read:org` scope
@ -75,7 +78,8 @@ func resolveOrganization(client *api.Client, orgName string) (string, error) {
var response struct {
NodeID string `json:"node_id"`
}
err := client.REST("GET", fmt.Sprintf("users/%s", orgName), nil, &response)
// TODO: GHE support
err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("users/%s", orgName), nil, &response)
return response.NodeID, err
}
@ -87,6 +91,7 @@ func resolveOrganizationTeam(client *api.Client, orgName, teamSlug string) (stri
NodeID string `json:"node_id"`
}
}
err := client.REST("GET", fmt.Sprintf("orgs/%s/teams/%s", orgName, teamSlug), nil, &response)
// TODO: GHE support
err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("orgs/%s/teams/%s", orgName, teamSlug), nil, &response)
return response.Organization.NodeID, response.NodeID, err
}

View file

@ -120,21 +120,17 @@ func creditsRun(opts *CreditsOptions) error {
client := api.NewClientFromHTTP(httpClient)
var owner string
var repo string
var baseRepo ghrepo.Interface
if opts.Repository == "" {
baseRepo, err := opts.BaseRepo()
baseRepo, err = opts.BaseRepo()
if err != nil {
return err
}
owner = baseRepo.RepoOwner()
repo = baseRepo.RepoName()
} else {
parts := strings.SplitN(opts.Repository, "/", 2)
owner = parts[0]
repo = parts[1]
baseRepo, err = ghrepo.FromFullName(opts.Repository)
if err != nil {
return err
}
}
type Contributor struct {
@ -145,9 +141,9 @@ func creditsRun(opts *CreditsOptions) error {
result := Result{}
body := bytes.NewBufferString("")
path := fmt.Sprintf("repos/%s/%s/contributors", owner, repo)
path := fmt.Sprintf("repos/%s/%s/contributors", baseRepo.RepoOwner(), baseRepo.RepoName())
err = client.REST("GET", path, body, &result)
err = client.REST(baseRepo.RepoHost(), "GET", path, body, &result)
if err != nil {
return err
}

View file

@ -24,7 +24,7 @@ func RepositoryReadme(client *http.Client, repo ghrepo.Interface) (*RepoReadme,
Content string
}
err := apiClient.REST("GET", fmt.Sprintf("repos/%s/readme", ghrepo.FullName(repo)), nil, &response)
err := apiClient.REST(repo.RepoHost(), "GET", fmt.Sprintf("repos/%s/readme", ghrepo.FullName(repo)), nil, &response)
if err != nil {
var httpError api.HTTPError
if errors.As(err, &httpError) && httpError.StatusCode == 404 {

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/cli/cli/api"
"github.com/cli/cli/internal/ghinstance"
"github.com/hashicorp/go-version"
"gopkg.in/yaml.v3"
)
@ -42,7 +43,7 @@ func getLatestReleaseInfo(client *api.Client, stateFilePath, repo, currentVersio
}
var latestRelease ReleaseInfo
err = client.REST("GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease)
err = client.REST(ghinstance.Default(), "GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease)
if err != nil {
return nil, err
}

View file

@ -48,9 +48,8 @@ func RenderMarkdown(text string) (string, error) {
func Pluralize(num int, thing string) string {
if num == 1 {
return fmt.Sprintf("%d %s", num, thing)
} else {
return fmt.Sprintf("%d %ss", num, thing)
}
return fmt.Sprintf("%d %ss", num, thing)
}
func fmtDuration(amount int, unit string) string {