From 288d01318b89f1177983543aceb6df744b1a23a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 16 Jul 2020 14:38:42 +0200 Subject: [PATCH 1/6] Respect the hostname of current repository in queries --- api/client.go | 30 +++---- api/client_test.go | 8 +- api/queries_gist.go | 52 ------------ api/queries_issue.go | 12 +-- api/queries_org.go | 20 +++-- api/queries_pr.go | 36 ++++----- api/queries_repo.go | 30 +++---- api/queries_user.go | 4 +- auth/oauth.go | 14 +++- command/credits.go | 21 +++-- command/issue.go | 2 +- command/pr.go | 4 +- command/pr_review.go | 4 +- command/repo.go | 4 +- command/root.go | 139 +++++++++++++++----------------- context/blank_context.go | 15 +--- context/context.go | 53 +++--------- internal/config/config_setup.go | 6 +- internal/ghinstance/host.go | 38 +++++++++ pkg/cmd/api/http.go | 2 +- pkg/cmd/gist/create/create.go | 4 +- pkg/cmd/gist/create/http.go | 4 +- update/update.go | 3 +- 23 files changed, 229 insertions(+), 276 deletions(-) delete mode 100644 api/queries_gist.go create mode 100644 internal/ghinstance/host.go diff --git a/api/client.go b/api/client.go index 1f16cffc7..d8ed2ce99 100644 --- a/api/client.go +++ b/api/client.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" + "github.com/cli/cli/internal/ghinstance" "github.com/henvic/httpretty" "github.com/shurcooL/graphql" ) @@ -43,25 +44,21 @@ func NewClientFromHTTP(httpClient *http.Client) *Client { func AddHeader(name, value string) ClientOption { return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - // prevent the token from leaking to non-GitHub hosts - // TODO: GHE support - if !strings.EqualFold(name, "Authorization") || strings.HasSuffix(req.URL.Hostname(), ".github.com") { - req.Header.Add(name, value) - } + req.Header.Add(name, value) return tr.RoundTrip(req) }} } } // AddHeaderFunc is an AddHeader that gets the string value from a function -func AddHeaderFunc(name string, value func() string) ClientOption { +func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption { return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { - // prevent the token from leaking to non-GitHub hosts - // TODO: GHE support - if !strings.EqualFold(name, "Authorization") || strings.HasSuffix(req.URL.Hostname(), ".github.com") { - req.Header.Add(name, value()) + value, err := getValue(req) + if err != nil { + return nil, err } + req.Header.Add(name, value) return tr.RoundTrip(req) }} } @@ -238,14 +235,13 @@ func (c Client) HasScopes(wantedScopes ...string) (bool, string, error) { } // GraphQL performs a GraphQL request and parses the response -func (c Client) GraphQL(query string, variables map[string]interface{}, data interface{}) error { - url := "https://api.github.com/graphql" +func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error { reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) if err != nil { return err } - req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody)) + req, err := http.NewRequest("POST", ghinstance.GraphQLEndpoint(hostname), bytes.NewBuffer(reqBody)) if err != nil { return err } @@ -261,13 +257,13 @@ func (c Client) GraphQL(query string, variables map[string]interface{}, data int return handleResponse(resp, data) } -func graphQLClient(h *http.Client) *graphql.Client { - return graphql.NewClient("https://api.github.com/graphql", h) +func graphQLClient(h *http.Client, hostname string) *graphql.Client { + return graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), h) } // REST performs a REST request and parses the response. -func (c Client) REST(method string, p string, body io.Reader, data interface{}) error { - url := "https://api.github.com/" + p +func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error { + url := ghinstance.RESTPrefix(hostname) + p req, err := http.NewRequest(method, url, body) if err != nil { return err diff --git a/api/client_test.go b/api/client_test.go index 7307ce2b6..063d4648c 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -33,7 +33,7 @@ func TestGraphQL(t *testing.T) { }{} http.StubResponse(200, bytes.NewBufferString(`{"data":{"viewer":{"login":"hubot"}}}`)) - err := client.GraphQL("QUERY", vars, &response) + err := client.GraphQL("github.com", "QUERY", vars, &response) eq(t, err, nil) eq(t, response.Viewer.Login, "hubot") @@ -55,7 +55,7 @@ func TestGraphQLError(t *testing.T) { ] }`)) - err := client.GraphQL("", nil, &response) + err := client.GraphQL("github.com", "", nil, &response) if err == nil || err.Error() != "GraphQL error: OH NO\nthis is fine" { t.Fatalf("got %q", err.Error()) } @@ -71,7 +71,7 @@ func TestRESTGetDelete(t *testing.T) { http.StubResponse(204, bytes.NewBuffer([]byte{})) r := bytes.NewReader([]byte(`{}`)) - err := client.REST("DELETE", "applications/CLIENTID/grant", r, nil) + err := client.REST("github.com", "DELETE", "applications/CLIENTID/grant", r, nil) eq(t, err, nil) } @@ -82,7 +82,7 @@ func TestRESTError(t *testing.T) { http.StubResponse(422, bytes.NewBufferString(`{"message": "OH NO"}`)) var httpErr HTTPError - err := client.REST("DELETE", "repos/branch", nil, nil) + err := client.REST("github.com", "DELETE", "repos/branch", nil, nil) if err == nil || !errors.As(err, &httpErr) { t.Fatalf("got %v", err) } diff --git a/api/queries_gist.go b/api/queries_gist.go deleted file mode 100644 index 7a9dc14b6..000000000 --- a/api/queries_gist.go +++ /dev/null @@ -1,52 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" -) - -// Gist represents a GitHub's gist. -type Gist struct { - Description string `json:"description,omitempty"` - Public bool `json:"public,omitempty"` - Files map[GistFilename]GistFile `json:"files,omitempty"` - HTMLURL string `json:"html_url,omitempty"` -} - -type GistFilename string - -type GistFile struct { - Content string `json:"content,omitempty"` -} - -// Create a gist for authenticated user. -// -// GitHub API docs: https://developer.github.com/v3/gists/#create-a-gist -func GistCreate(client *Client, description string, public bool, files map[string]string) (*Gist, error) { - gistFiles := map[GistFilename]GistFile{} - - for filename, content := range files { - gistFiles[GistFilename(filename)] = GistFile{content} - } - - path := "gists" - body := &Gist{ - Description: description, - Public: public, - Files: gistFiles, - } - result := Gist{} - - requestByte, err := json.Marshal(body) - if err != nil { - return nil, err - } - requestBody := bytes.NewReader(requestByte) - - err = client.REST("POST", path, requestBody, &result) - if err != nil { - return nil, err - } - - return &result, nil -} diff --git a/api/queries_issue.go b/api/queries_issue.go index 1b3bddd64..4f2c0eb14 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -112,7 +112,7 @@ func IssueCreate(client *Client, repo *Repository, params map[string]interface{} } }{} - err := client.GraphQL(query, variables, &result) + err := client.GraphQL(repo.RepoHost(), query, variables, &result) if err != nil { return nil, err } @@ -171,7 +171,7 @@ func IssueStatus(client *Client, repo ghrepo.Interface, currentUsername string) } var resp response - err := client.GraphQL(query, variables, &resp) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -270,7 +270,7 @@ func IssueList(client *Client, repo ghrepo.Interface, state string, labels []str loop: for { variables["limit"] = pageLimit - err := client.GraphQL(query, variables, &response) + err := client.GraphQL(repo.RepoHost(), query, variables, &response) if err != nil { return nil, err } @@ -361,7 +361,7 @@ func IssueByNumber(client *Client, repo ghrepo.Interface, number int) (*Issue, e } var resp response - err := client.GraphQL(query, variables, &resp) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -389,7 +389,7 @@ func IssueClose(client *Client, repo ghrepo.Interface, issue Issue) error { }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) err := gql.MutateNamed(context.Background(), "IssueClose", &mutation, variables) if err != nil { @@ -414,7 +414,7 @@ func IssueReopen(client *Client, repo ghrepo.Interface, issue Issue) error { }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) err := gql.MutateNamed(context.Background(), "IssueReopen", &mutation, variables) return err diff --git a/api/queries_org.go b/api/queries_org.go index b91c2da78..54f05a70f 100644 --- a/api/queries_org.go +++ b/api/queries_org.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/cli/cli/internal/ghinstance" + "github.com/cli/cli/internal/ghrepo" "github.com/shurcooL/githubv4" ) @@ -12,7 +14,8 @@ func resolveOrganization(client *Client, orgName string) (string, error) { var response struct { NodeID string `json:"node_id"` } - err := client.REST("GET", fmt.Sprintf("users/%s", orgName), nil, &response) + // TODO: GHE support + err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("users/%s", orgName), nil, &response) return response.NodeID, err } @@ -24,12 +27,13 @@ func resolveOrganizationTeam(client *Client, orgName, teamSlug string) (string, NodeID string `json:"node_id"` } } - err := client.REST("GET", fmt.Sprintf("orgs/%s/teams/%s", orgName, teamSlug), nil, &response) + // TODO: GHE support + err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("orgs/%s/teams/%s", orgName, teamSlug), nil, &response) return response.Organization.NodeID, response.NodeID, err } // OrganizationProjects fetches all open projects for an organization -func OrganizationProjects(client *Client, owner string) ([]RepoProject, error) { +func OrganizationProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) { var query struct { Organization struct { Projects struct { @@ -43,11 +47,11 @@ func OrganizationProjects(client *Client, owner string) ([]RepoProject, error) { } variables := map[string]interface{}{ - "owner": githubv4.String(owner), + "owner": githubv4.String(repo.RepoOwner()), "endCursor": (*githubv4.String)(nil), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) var projects []RepoProject for { @@ -72,7 +76,7 @@ type OrgTeam struct { } // OrganizationTeams fetches all the teams in an organization -func OrganizationTeams(client *Client, owner string) ([]OrgTeam, error) { +func OrganizationTeams(client *Client, repo ghrepo.Interface) ([]OrgTeam, error) { var query struct { Organization struct { Teams struct { @@ -86,11 +90,11 @@ func OrganizationTeams(client *Client, owner string) ([]OrgTeam, error) { } variables := map[string]interface{}{ - "owner": githubv4.String(owner), + "owner": githubv4.String(repo.RepoOwner()), "endCursor": (*githubv4.String)(nil), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) var teams []OrgTeam for { diff --git a/api/queries_pr.go b/api/queries_pr.go index 4437f8f7e..642e83c34 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -363,7 +363,7 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } var resp response - err := client.GraphQL(query, variables, &resp) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -500,7 +500,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu } var resp response - err := client.GraphQL(query, variables, &resp) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -613,7 +613,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea } var resp response - err := client.GraphQL(query, variables, &resp) + err := client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -663,7 +663,7 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter } }{} - err := client.GraphQL(query, variables, &result) + err := client.GraphQL(repo.RepoHost(), query, variables, &result) if err != nil { return nil, err } @@ -689,7 +689,7 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter variables := map[string]interface{}{ "input": updateParams, } - err := client.GraphQL(updateQuery, variables, &result) + err := client.GraphQL(repo.RepoHost(), updateQuery, variables, &result) if err != nil { return nil, err } @@ -714,7 +714,7 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter variables := map[string]interface{}{ "input": reviewParams, } - err := client.GraphQL(reviewQuery, variables, &result) + err := client.GraphQL(repo.RepoHost(), reviewQuery, variables, &result) if err != nil { return nil, err } @@ -734,7 +734,7 @@ func isBlank(v interface{}) bool { } } -func AddReview(client *Client, pr *PullRequest, input *PullRequestReviewInput) error { +func AddReview(client *Client, repo ghrepo.Interface, pr *PullRequest, input *PullRequestReviewInput) error { var mutation struct { AddPullRequestReview struct { ClientMutationID string @@ -758,11 +758,11 @@ func AddReview(client *Client, pr *PullRequest, input *PullRequestReviewInput) e }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) return gql.MutateNamed(context.Background(), "PullRequestReviewAdd", &mutation, variables) } -func PullRequestList(client *Client, vars map[string]interface{}, limit int) (*PullRequestAndTotalCount, error) { +func PullRequestList(client *Client, repo ghrepo.Interface, vars map[string]interface{}, limit int) (*PullRequestAndTotalCount, error) { type prBlock struct { Edges []struct { Node PullRequest @@ -859,10 +859,8 @@ func PullRequestList(client *Client, vars map[string]interface{}, limit int) (*P } } }` - owner := vars["owner"].(string) - repo := vars["repo"].(string) search := []string{ - fmt.Sprintf("repo:%s/%s", owner, repo), + fmt.Sprintf("repo:%s/%s", repo.RepoOwner(), repo.RepoName()), fmt.Sprintf("assignee:%s", assignee), "is:pr", "sort:created-desc", @@ -888,6 +886,8 @@ func PullRequestList(client *Client, vars map[string]interface{}, limit int) (*P } variables["q"] = strings.Join(search, " ") } else { + variables["owner"] = repo.RepoOwner() + variables["repo"] = repo.RepoName() for name, val := range vars { variables[name] = val } @@ -896,7 +896,7 @@ loop: for { variables["limit"] = pageLimit var data response - err := client.GraphQL(query, variables, &data) + err := client.GraphQL(repo.RepoHost(), query, variables, &data) if err != nil { return nil, err } @@ -945,7 +945,7 @@ func PullRequestClose(client *Client, repo ghrepo.Interface, pr *PullRequest) er }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) err := gql.MutateNamed(context.Background(), "PullRequestClose", &mutation, variables) return err @@ -966,7 +966,7 @@ func PullRequestReopen(client *Client, repo ghrepo.Interface, pr *PullRequest) e }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) err := gql.MutateNamed(context.Background(), "PullRequestReopen", &mutation, variables) return err @@ -996,7 +996,7 @@ func PullRequestMerge(client *Client, repo ghrepo.Interface, pr *PullRequest, m }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) err := gql.MutateNamed(context.Background(), "PullRequestMerge", &mutation, variables) return err @@ -1017,13 +1017,13 @@ func PullRequestReady(client *Client, repo ghrepo.Interface, pr *PullRequest) er }, } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) return gql.MutateNamed(context.Background(), "PullRequestReadyForReview", &mutation, variables) } func BranchDeleteRemote(client *Client, repo ghrepo.Interface, branch string) error { path := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.RepoOwner(), repo.RepoName(), branch) - return client.REST("DELETE", path, nil, nil) + return client.REST(repo.RepoHost(), "DELETE", path, nil, nil) } func min(a, b int) int { diff --git a/api/queries_repo.go b/api/queries_repo.go index 35079095c..40ebb9955 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/internal/ghrepo" "github.com/shurcooL/githubv4" ) @@ -104,7 +105,7 @@ func GitHubRepo(client *Client, repo ghrepo.Interface) (*Repository, error) { result := struct { Repository Repository }{} - err := client.GraphQL(query, variables, &result) + err := client.GraphQL(repo.RepoHost(), query, variables, &result) if err != nil { return nil, err @@ -143,7 +144,7 @@ func RepoParent(client *Client, repo ghrepo.Interface) (ghrepo.Interface, error) "name": githubv4.String(repo.RepoName()), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) err := gql.QueryNamed(context.Background(), "RepositoryFindParent", &query, variables) if err != nil { return nil, err @@ -187,7 +188,7 @@ func RepoNetwork(client *Client, repos []ghrepo.Interface) (RepoNetworkResult, e graphqlResult := make(map[string]*json.RawMessage) var result RepoNetworkResult - err := client.GraphQL(fmt.Sprintf(` + err := client.GraphQL(hostname, fmt.Sprintf(` fragment repo on Repository { id name @@ -283,7 +284,7 @@ func ForkRepo(client *Client, repo ghrepo.Interface) (*Repository, error) { path := fmt.Sprintf("repos/%s/forks", ghrepo.FullName(repo)) body := bytes.NewBufferString(`{}`) result := repositoryV3{} - err := client.REST("POST", path, body, &result) + err := client.REST(repo.RepoHost(), "POST", path, body, &result) if err != nil { return nil, err } @@ -316,7 +317,7 @@ func RepoFindFork(client *Client, repo ghrepo.Interface) (*Repository, error) { "repo": repo.RepoName(), } - if err := client.GraphQL(` + if err := client.GraphQL(repo.RepoHost(), ` query RepositoryFindFork($owner: String!, $repo: String!) { repository(owner: $owner, name: $repo) { forks(first: 1, affiliations: [OWNER, COLLABORATOR]) { @@ -385,7 +386,8 @@ func RepoCreate(client *Client, input RepoCreateInput) (*Repository, error) { "input": input, } - err := client.GraphQL(` + // TODO: GHE support + err := client.GraphQL(ghinstance.Default(), ` mutation RepositoryCreate($input: CreateRepositoryInput!) { createRepository(input: $input) { repository { @@ -416,7 +418,7 @@ func RepositoryReadme(client *Client, repo ghrepo.Interface) (*RepoReadme, error Content string } - err := client.REST("GET", fmt.Sprintf("repos/%s/readme", ghrepo.FullName(repo)), nil, &response) + err := client.REST(repo.RepoHost(), "GET", fmt.Sprintf("repos/%s/readme", ghrepo.FullName(repo)), nil, &response) if err != nil { var httpError HTTPError if errors.As(err, &httpError) && httpError.StatusCode == 404 { @@ -554,7 +556,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput if input.Reviewers { count++ go func() { - teams, err := OrganizationTeams(client, repo.RepoOwner()) + teams, err := OrganizationTeams(client, repo) // TODO: better detection of non-org repos if err != nil && !strings.HasPrefix(err.Error(), "Could not resolve to an Organization") { errc <- fmt.Errorf("error fetching organization teams: %w", err) @@ -585,7 +587,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput } result.Projects = projects - orgProjects, err := OrganizationProjects(client, repo.RepoOwner()) + orgProjects, err := OrganizationProjects(client, repo) // TODO: better detection of non-org repos if err != nil && !strings.HasPrefix(err.Error(), "Could not resolve to an Organization") { errc <- fmt.Errorf("error fetching organization projects: %w", err) @@ -681,7 +683,7 @@ func RepoResolveMetadataIDs(client *Client, repo ghrepo.Interface, input RepoRes fmt.Fprint(query, "}\n") response := make(map[string]json.RawMessage) - err = client.GraphQL(query.String(), nil, &response) + err = client.GraphQL(repo.RepoHost(), query.String(), nil, &response) if err != nil { return result, err } @@ -744,7 +746,7 @@ func RepoProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) "endCursor": (*githubv4.String)(nil), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) var projects []RepoProject for { @@ -788,7 +790,7 @@ func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]RepoAssignee, "endCursor": (*githubv4.String)(nil), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) var users []RepoAssignee for { @@ -832,7 +834,7 @@ func RepoLabels(client *Client, repo ghrepo.Interface) ([]RepoLabel, error) { "endCursor": (*githubv4.String)(nil), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) var labels []RepoLabel for { @@ -876,7 +878,7 @@ func RepoMilestones(client *Client, repo ghrepo.Interface) ([]RepoMilestone, err "endCursor": (*githubv4.String)(nil), } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, repo.RepoHost()) var milestones []RepoMilestone for { diff --git a/api/queries_user.go b/api/queries_user.go index ea31c59ea..0a9b68bd2 100644 --- a/api/queries_user.go +++ b/api/queries_user.go @@ -4,13 +4,13 @@ import ( "context" ) -func CurrentLoginName(client *Client) (string, error) { +func CurrentLoginName(client *Client, hostname string) (string, error) { var query struct { Viewer struct { Login string } } - gql := graphQLClient(client.http) + gql := graphQLClient(client.http, hostname) err := gql.QueryNamed(context.Background(), "UserCurrent", &query, nil) return query.Viewer.Login, err } diff --git a/auth/oauth.go b/auth/oauth.go index 3568956c5..1d421fa17 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -13,6 +13,7 @@ import ( "os" "strings" + "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/pkg/browser" ) @@ -52,9 +53,18 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { scopes = strings.Join(oa.Scopes, " ") } + localhost := "127.0.0.1" + callbackPath := "/callback" + if ghinstance.IsEnterprise(oa.Hostname) { + // the OAuth app on Enterprise hosts is still registered with a legacy callback URL + // see https://github.com/cli/cli/pull/222, https://github.com/cli/cli/pull/650 + localhost = "localhost" + callbackPath = "/" + } + q := url.Values{} q.Set("client_id", oa.ClientID) - q.Set("redirect_uri", fmt.Sprintf("http://127.0.0.1:%d/callback", port)) + q.Set("redirect_uri", fmt.Sprintf("http://%s:%d%s", localhost, port, callbackPath)) q.Set("scope", scopes) q.Set("state", state) @@ -73,7 +83,7 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { _ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { oa.logf("server handler: %s\n", r.URL.Path) - if r.URL.Path != "/callback" { + if r.URL.Path != callbackPath { w.WriteHeader(404) return } diff --git a/command/credits.go b/command/credits.go index 0104edb02..421890f1f 100644 --- a/command/credits.go +++ b/command/credits.go @@ -15,6 +15,7 @@ import ( "github.com/spf13/cobra" "golang.org/x/crypto/ssh/terminal" + "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/utils" ) @@ -73,21 +74,17 @@ func credits(cmd *cobra.Command, args []string) error { return err } - var owner string - var repo string - + var baseRepo ghrepo.Interface if len(args) == 0 { - baseRepo, err := determineBaseRepo(client, cmd, ctx) + baseRepo, err = determineBaseRepo(client, cmd, ctx) if err != nil { return err } - - owner = baseRepo.RepoOwner() - repo = baseRepo.RepoName() } else { - parts := strings.SplitN(args[0], "/", 2) - owner = parts[0] - repo = parts[1] + baseRepo, err = ghrepo.FromFullName(args[0]) + if err != nil { + return err + } } type Contributor struct { @@ -98,9 +95,9 @@ func credits(cmd *cobra.Command, args []string) error { result := Result{} body := bytes.NewBufferString("") - path := fmt.Sprintf("repos/%s/%s/contributors", owner, repo) + path := fmt.Sprintf("repos/%s/%s/contributors", baseRepo.RepoOwner(), baseRepo.RepoName()) - err = client.REST("GET", path, body, &result) + err = client.REST(baseRepo.RepoHost(), "GET", path, body, &result) if err != nil { return err } diff --git a/command/issue.go b/command/issue.go index 5de86da91..84ddaf8c8 100644 --- a/command/issue.go +++ b/command/issue.go @@ -281,7 +281,7 @@ func issueStatus(cmd *cobra.Command, args []string) error { return err } - currentUser, err := api.CurrentLoginName(apiClient) + currentUser, err := api.CurrentLoginName(apiClient, baseRepo.RepoHost()) if err != nil { return err } diff --git a/command/pr.go b/command/pr.go index 9978d18be..f385fc6c6 100644 --- a/command/pr.go +++ b/command/pr.go @@ -265,8 +265,6 @@ func prList(cmd *cobra.Command, args []string) error { } params := map[string]interface{}{ - "owner": baseRepo.RepoOwner(), - "repo": baseRepo.RepoName(), "state": graphqlState, } if len(labels) > 0 { @@ -279,7 +277,7 @@ func prList(cmd *cobra.Command, args []string) error { params["assignee"] = assignee } - listResult, err := api.PullRequestList(apiClient, params, limit) + listResult, err := api.PullRequestList(apiClient, baseRepo, params, limit) if err != nil { return err } diff --git a/command/pr_review.go b/command/pr_review.go index a9681d619..55c402e3f 100644 --- a/command/pr_review.go +++ b/command/pr_review.go @@ -99,7 +99,7 @@ func prReview(cmd *cobra.Command, args []string) error { return err } - pr, _, err := prFromArgs(ctx, apiClient, cmd, args) + pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args) if err != nil { return err } @@ -122,7 +122,7 @@ func prReview(cmd *cobra.Command, args []string) error { } } - err = api.AddReview(apiClient, pr, reviewData) + err = api.AddReview(apiClient, baseRepo, pr, reviewData) if err != nil { return fmt.Errorf("failed to create review: %w", err) } diff --git a/command/repo.go b/command/repo.go index d705a4a5f..a8a8630d5 100644 --- a/command/repo.go +++ b/command/repo.go @@ -14,6 +14,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" "github.com/cli/cli/git" + "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/internal/run" "github.com/cli/cli/utils" @@ -181,7 +182,8 @@ func repoClone(cmd *cobra.Command, args []string) error { cloneURL := args[0] if !strings.Contains(cloneURL, ":") { if !strings.Contains(cloneURL, "/") { - currentUser, err := api.CurrentLoginName(apiClient) + // TODO: GHE support + currentUser, err := api.CurrentLoginName(apiClient, ghinstance.Default()) if err != nil { return err } diff --git a/command/root.go b/command/root.go index 01f7d524a..06fed7510 100644 --- a/command/root.go +++ b/command/root.go @@ -17,6 +17,7 @@ import ( "github.com/cli/cli/api" "github.com/cli/cli/context" "github.com/cli/cli/internal/config" + "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/internal/run" apiCmd "github.com/cli/cli/pkg/cmd/api" @@ -30,9 +31,6 @@ import ( "github.com/spf13/pflag" ) -// TODO these are sprinkled across command, context, config, and ghrepo -const defaultHostname = "github.com" - // Version is dynamically set by the toolchain or overridden by the Makefile. var Version = "DEV" @@ -41,6 +39,8 @@ var BuildDate = "" // YYYY-MM-DD var versionOutput = "" +var defaultStreams *iostreams.IOStreams + func init() { if Version == "DEV" { if info, ok := debug.ReadBuildInfo(); ok && info.Main.Version != "(devel)" { @@ -72,22 +72,21 @@ func init() { return &cmdutil.FlagError{Err: err} }) + defaultStreams = iostreams.System() + // TODO: iron out how a factory incorporates context cmdFactory := &cmdutil.Factory{ - IOStreams: iostreams.System(), + IOStreams: defaultStreams, HttpClient: func() (*http.Client, error) { - token := os.Getenv("GITHUB_TOKEN") - if len(token) == 0 { - // TODO: decouple from `context` - ctx := context.New() - var err error - // TODO: pass IOStreams to this so that the auth flow knows if it's interactive or not - token, err = ctx.AuthToken() - if err != nil { - return nil, err - } + // TODO: decouple from `context` + ctx := context.New() + cfg, err := ctx.Config() + if err != nil { + return nil, err } - return httpClient(token), nil + + // TODO: avoid setting Accept header for `api` command + return httpClient(defaultStreams, cfg, true), nil }, BaseRepo: func() (ghrepo.Interface, error) { // TODO: decouple from `context` @@ -158,14 +157,11 @@ var initContext = func() context.Context { if repo := os.Getenv("GH_REPO"); repo != "" { ctx.SetBaseRepo(repo) } - if token := os.Getenv("GITHUB_TOKEN"); token != "" { - ctx.SetAuthToken(token) - } return ctx } -// BasicClient returns an API client that borrows from but does not depend on -// user configuration +// BasicClient returns an API client for github.com only that borrows from but +// does not depend on user configuration func BasicClient() (*api.Client, error) { var opts []api.ClientOption if verbose := os.Getenv("DEBUG"); verbose != "" { @@ -176,7 +172,7 @@ func BasicClient() (*api.Client, error) { token := os.Getenv("GITHUB_TOKEN") if token == "" { if c, err := config.ParseDefaultConfig(); err == nil { - token, _ = c.Get(defaultHostname, "oauth_token") + token, _ = c.Get(ghinstance.Default(), "oauth_token") } } if token != "" { @@ -193,72 +189,68 @@ func contextForCommand(cmd *cobra.Command) context.Context { return ctx } -// for cmdutil-powered commands -func httpClient(token string) *http.Client { +// generic authenticated HTTP client for commands +func httpClient(io *iostreams.IOStreams, cfg config.Config, setAccept bool) *http.Client { var opts []api.ClientOption if verbose := os.Getenv("DEBUG"); verbose != "" { opts = append(opts, apiVerboseLog()) } + opts = append(opts, - api.AddHeader("Authorization", fmt.Sprintf("token %s", token)), api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", Version)), + api.AddHeaderFunc("Authorization", func(req *http.Request) (string, error) { + if token := os.Getenv("GITHUB_TOKEN"); token != "" { + return fmt.Sprintf("token %s", token), nil + } + + hostname := ghinstance.NormalizeHostname(req.URL.Hostname()) + token, err := cfg.Get(hostname, "oauth_token") + if token == "" { + var notFound *config.NotFoundError + // TODO: check if stdout is TTY too + if errors.As(err, ¬Found) && io.IsStdinTTY() { + // interactive OAuth flow + token, err = config.AuthFlowWithConfig(cfg, hostname, "Notice: authentication required") + } + if err != nil { + return "", err + } + if token == "" { + // TODO: instruct user how to manually authenticate + return "", fmt.Errorf("authentication required for %s", hostname) + } + } + + return fmt.Sprintf("token %s", token), nil + }), ) + + if setAccept { + opts = append(opts, + api.AddHeaderFunc("Accept", func(req *http.Request) (string, error) { + // antiope-preview: Checks + accept := "application/vnd.github.antiope-preview+json" + if ghinstance.IsEnterprise(req.URL.Hostname()) { + // shadow-cat-preview: Draft pull requests + accept += ", application/vnd.github.shadow-cat-preview" + } + return accept, nil + }), + ) + } + return api.NewHTTPClient(opts...) } -// overridden in tests +// LEGACY; overridden in tests var apiClientForContext = func(ctx context.Context) (*api.Client, error) { - token, err := ctx.AuthToken() + cfg, err := ctx.Config() if err != nil { return nil, err } - var opts []api.ClientOption - if verbose := os.Getenv("DEBUG"); verbose != "" { - opts = append(opts, apiVerboseLog()) - } - - getAuthValue := func() string { - return fmt.Sprintf("token %s", token) - } - - tokenFromEnv := func() bool { - return os.Getenv("GITHUB_TOKEN") == token - } - - checkScopesFunc := func(appID string) error { - if config.IsGitHubApp(appID) && !tokenFromEnv() && utils.IsTerminal(os.Stdin) && utils.IsTerminal(os.Stderr) { - cfg, err := ctx.Config() - if err != nil { - return err - } - newToken, err := config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: additional authorization required") - if err != nil { - return err - } - // update configuration in memory - token = newToken - } else { - fmt.Fprintln(os.Stderr, "Warning: gh now requires the `read:org` OAuth scope.") - fmt.Fprintln(os.Stderr, "Visit https://github.com/settings/tokens and edit your token to enable `read:org`") - if tokenFromEnv() { - fmt.Fprintln(os.Stderr, "or generate a new token for the GITHUB_TOKEN environment variable") - } else { - fmt.Fprintln(os.Stderr, "or generate a new token and paste it via `gh config set -h github.com oauth_token MYTOKEN`") - } - } - return nil - } - - opts = append(opts, - api.CheckScopes("read:org", checkScopesFunc), - api.AddHeaderFunc("Authorization", getAuthValue), - api.AddHeader("User-Agent", fmt.Sprintf("GitHub CLI %s", Version)), - // antiope-preview: Checks - api.AddHeader("Accept", "application/vnd.github.antiope-preview+json"), - ) - - return api.NewClient(opts...), nil + http := httpClient(defaultStreams, cfg, true) + return api.NewClientFromHTTP(http), nil } func apiVerboseLog() api.ClientOption { @@ -348,7 +340,8 @@ func determineEditor(cmd *cobra.Command) (string, error) { if err != nil { return "", fmt.Errorf("could not read config: %w", err) } - editorCommand, _ = cfg.Get(defaultHostname, "editor") + // TODO: consider supporting setting an editor per GHE host + editorCommand, _ = cfg.Get(ghinstance.Default(), "editor") } return editorCommand, nil diff --git a/context/blank_context.go b/context/blank_context.go index 3035b4d21..f14310937 100644 --- a/context/blank_context.go +++ b/context/blank_context.go @@ -16,10 +16,9 @@ func NewBlank() *blankContext { // A Context implementation that queries the filesystem type blankContext struct { - authToken string - branch string - baseRepo ghrepo.Interface - remotes Remotes + branch string + baseRepo ghrepo.Interface + remotes Remotes } func (c *blankContext) Config() (config.Config, error) { @@ -30,14 +29,6 @@ func (c *blankContext) Config() (config.Config, error) { return cfg, nil } -func (c *blankContext) AuthToken() (string, error) { - return c.authToken, nil -} - -func (c *blankContext) SetAuthToken(t string) { - c.authToken = t -} - func (c *blankContext) Branch() (string, error) { if c.branch == "" { return "", fmt.Errorf("branch was not initialized: %w", git.ErrNotOnAnyBranch) diff --git a/context/context.go b/context/context.go index e83e0b44e..6cff06b4c 100644 --- a/context/context.go +++ b/context/context.go @@ -13,13 +13,8 @@ import ( "github.com/cli/cli/internal/ghrepo" ) -// TODO these are sprinkled across command, context, config, and ghrepo -const defaultHostname = "github.com" - // Context represents the interface for querying information about the current environment type Context interface { - AuthToken() (string, error) - SetAuthToken(string) Branch() (string, error) SetBranch(string) Remotes() (Remotes, error) @@ -164,11 +159,10 @@ func New() Context { // A Context implementation that queries the filesystem type fsContext struct { - config config.Config - remotes Remotes - branch string - baseRepo ghrepo.Interface - authToken string + config config.Config + remotes Remotes + branch string + baseRepo ghrepo.Interface } func (c *fsContext) Config() (config.Config, error) { @@ -180,37 +174,10 @@ func (c *fsContext) Config() (config.Config, error) { return nil, err } c.config = cfg - c.authToken = "" } return c.config, nil } -func (c *fsContext) AuthToken() (string, error) { - if c.authToken != "" { - return c.authToken, nil - } - - cfg, err := c.Config() - if err != nil { - return "", err - } - - var notFound *config.NotFoundError - token, err := cfg.Get(defaultHostname, "oauth_token") - if token == "" || errors.As(err, ¬Found) { - // interactive OAuth flow - return config.AuthFlowWithConfig(cfg, defaultHostname, "Notice: authentication required") - } else if err != nil { - return "", err - } - - return token, nil -} - -func (c *fsContext) SetAuthToken(t string) { - c.authToken = t -} - func (c *fsContext) Branch() (string, error) { if c.branch != "" { return c.branch, nil @@ -242,11 +209,16 @@ func (c *fsContext) Remotes() (Remotes, error) { sshTranslate := git.ParseSSHConfig().Translator() resolvedRemotes := translateRemotes(gitRemotes, sshTranslate) - // ignore non-github.com remotes - // TODO: GHE compatibility + // determine hostname by looking at the "main" remote + var hostname string + if mainRemote, err := resolvedRemotes.FindByName("upstream", "github", "origin", "*"); err == nil { + hostname = mainRemote.RepoHost() + } + + // filter the rest of the remotes to just that hostname filteredRemotes := Remotes{} for _, r := range resolvedRemotes { - if r.RepoHost() != defaultHostname { + if r.RepoHost() != hostname { continue } filteredRemotes = append(filteredRemotes, r) @@ -255,7 +227,6 @@ func (c *fsContext) Remotes() (Remotes, error) { } if len(c.remotes) == 0 { - // TODO: GHE compatibility return nil, errors.New("no git remote found for a github.com repository") } return c.remotes, nil diff --git a/internal/config/config_setup.go b/internal/config/config_setup.go index 7217eb3d9..0d40e1a0f 100644 --- a/internal/config/config_setup.go +++ b/internal/config/config_setup.go @@ -74,7 +74,7 @@ func authFlow(oauthHost, notice string) (string, string, error) { return "", "", err } - userLogin, err := getViewer(token) + userLogin, err := getViewer(oauthHost, token) if err != nil { return "", "", err } @@ -87,9 +87,9 @@ func AuthFlowComplete() { _ = waitForEnter(os.Stdin) } -func getViewer(token string) (string, error) { +func getViewer(hostname, token string) (string, error) { http := api.NewClient(api.AddHeader("Authorization", fmt.Sprintf("token %s", token))) - return api.CurrentLoginName(http) + return api.CurrentLoginName(http, hostname) } func waitForEnter(r io.Reader) error { diff --git a/internal/ghinstance/host.go b/internal/ghinstance/host.go new file mode 100644 index 000000000..83501cb6d --- /dev/null +++ b/internal/ghinstance/host.go @@ -0,0 +1,38 @@ +package ghinstance + +import ( + "fmt" + "strings" +) + +const defaultHostname = "github.com" + +func Default() string { + return defaultHostname +} + +func IsEnterprise(h string) bool { + return NormalizeHostname(h) != defaultHostname +} + +func NormalizeHostname(h string) string { + hostname := strings.ToLower(h) + if strings.HasSuffix(hostname, "."+defaultHostname) { + return defaultHostname + } + return hostname +} + +func GraphQLEndpoint(hostname string) string { + if IsEnterprise(hostname) { + return fmt.Sprintf("https://%s/api/graphql", hostname) + } + return "https://api.github.com/graphql" +} + +func RESTPrefix(hostname string) string { + if IsEnterprise(hostname) { + return fmt.Sprintf("https://%s/api/v3/", hostname) + } + return "https://api.github.com/" +} diff --git a/pkg/cmd/api/http.go b/pkg/cmd/api/http.go index e393bb593..1ff5b0d0b 100644 --- a/pkg/cmd/api/http.go +++ b/pkg/cmd/api/http.go @@ -13,10 +13,10 @@ import ( func httpRequest(client *http.Client, method string, p string, params interface{}, headers []string) (*http.Response, error) { var requestURL string - // TODO: GHE support if strings.Contains(p, "://") { requestURL = p } else { + // TODO: GHE support requestURL = "https://api.github.com/" + p } diff --git a/pkg/cmd/gist/create/create.go b/pkg/cmd/gist/create/create.go index d4fcd276c..8151c408b 100644 --- a/pkg/cmd/gist/create/create.go +++ b/pkg/cmd/gist/create/create.go @@ -11,6 +11,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" + "github.com/cli/cli/internal/ghinstance" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" "github.com/cli/cli/utils" @@ -105,7 +106,8 @@ func createRun(opts *CreateOptions) error { return err } - gist, err := apiCreate(httpClient, opts.Description, opts.Public, files) + // TODO: GHE support + gist, err := apiCreate(httpClient, ghinstance.Default(), opts.Description, opts.Public, files) if err != nil { var httpError api.HTTPError if errors.As(err, &httpError) { diff --git a/pkg/cmd/gist/create/http.go b/pkg/cmd/gist/create/http.go index 55ea61033..42b970221 100644 --- a/pkg/cmd/gist/create/http.go +++ b/pkg/cmd/gist/create/http.go @@ -22,7 +22,7 @@ type GistFile struct { Content string `json:"content,omitempty"` } -func apiCreate(httpClient *http.Client, description string, public bool, files map[string]string) (*Gist, error) { +func apiCreate(httpClient *http.Client, hostname string, description string, public bool, files map[string]string) (*Gist, error) { gistFiles := map[GistFilename]GistFile{} for filename, content := range files { @@ -44,7 +44,7 @@ func apiCreate(httpClient *http.Client, description string, public bool, files m requestBody := bytes.NewReader(requestByte) apiClient := api.NewClientFromHTTP(httpClient) - err = apiClient.REST("POST", path, requestBody, &result) + err = apiClient.REST(hostname, "POST", path, requestBody, &result) if err != nil { return nil, err } diff --git a/update/update.go b/update/update.go index 60abff73c..bf89a12e8 100644 --- a/update/update.go +++ b/update/update.go @@ -6,6 +6,7 @@ import ( "time" "github.com/cli/cli/api" + "github.com/cli/cli/internal/ghinstance" "github.com/hashicorp/go-version" "gopkg.in/yaml.v3" ) @@ -42,7 +43,7 @@ func getLatestReleaseInfo(client *api.Client, stateFilePath, repo, currentVersio } var latestRelease ReleaseInfo - err = client.REST("GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease) + err = client.REST(ghinstance.Default(), "GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease) if err != nil { return nil, err } From 8909a3e5c347dd6438bd92739dfed3d7bbfaf510 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 23 Jul 2020 23:11:24 +0200 Subject: [PATCH 2/6] Try to avoid CodeQL warning https://github.com/cli/cli/pull/1415/checks?check_run_id=904308295 --- pkg/browser/browser.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pkg/browser/browser.go b/pkg/browser/browser.go index 1f926c462..67ebca92f 100644 --- a/pkg/browser/browser.go +++ b/pkg/browser/browser.go @@ -20,20 +20,21 @@ func Command(url string) (*exec.Cmd, error) { // ForOS produces an exec.Cmd to open the web browser for different OS func ForOS(goos, url string) *exec.Cmd { + exe := "open" var args []string switch goos { case "darwin": - args = []string{"open"} + args = append(args, url) case "windows": - args = []string{"cmd", "/c", "start"} + exe = "cmd" r := strings.NewReplacer("&", "^&") - url = r.Replace(url) + args = append(args, "/c", "start", r.Replace(url)) default: - args = []string{"xdg-open"} + exe = "xdg-open" + args = append(args, url) } - args = append(args, url) - cmd := exec.Command(args[0], args[1:]...) + cmd := exec.Command(exe, args...) cmd.Stderr = os.Stderr return cmd } From c1c836a657b735a6581aa70819d269d0cf453caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 27 Jul 2020 16:31:05 +0200 Subject: [PATCH 3/6] Remove hardcoded "github.com" --- api/queries_repo.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/queries_repo.go b/api/queries_repo.go index 745f06623..4aa5cf02e 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -386,7 +386,9 @@ func RepoCreate(client *Client, input RepoCreateInput) (*Repository, error) { } // TODO: GHE support - err := client.GraphQL(ghinstance.Default(), ` + hostname := ghinstance.Default() + + err := client.GraphQL(hostname, ` mutation RepositoryCreate($input: CreateRepositoryInput!) { createRepository(input: $input) { repository { @@ -402,8 +404,7 @@ func RepoCreate(client *Client, input RepoCreateInput) (*Repository, error) { return nil, err } - // FIXME: support Enterprise hosts - return initRepoHostname(&response.CreateRepository.Repository, "github.com"), nil + return initRepoHostname(&response.CreateRepository.Repository, hostname), nil } type RepoMetadataResult struct { From 87632791d639866e430aefaadcb98e0d24168951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 27 Jul 2020 20:09:57 +0200 Subject: [PATCH 4/6] Add ghinstance tests --- internal/ghinstance/host.go | 3 + internal/ghinstance/host_test.go | 121 +++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 internal/ghinstance/host_test.go diff --git a/internal/ghinstance/host.go b/internal/ghinstance/host.go index 83501cb6d..c05c3b263 100644 --- a/internal/ghinstance/host.go +++ b/internal/ghinstance/host.go @@ -7,14 +7,17 @@ import ( const defaultHostname = "github.com" +// Default returns the host name of the default GitHub instance func Default() string { return defaultHostname } +// IsEnterprise reports whether a non-normalized host name looks like a GHE instance func IsEnterprise(h string) bool { return NormalizeHostname(h) != defaultHostname } +// NormalizeHostname returns the canonical host name of a GitHub instance func NormalizeHostname(h string) string { hostname := strings.ToLower(h) if strings.HasSuffix(hostname, "."+defaultHostname) { diff --git a/internal/ghinstance/host_test.go b/internal/ghinstance/host_test.go new file mode 100644 index 000000000..26515dc39 --- /dev/null +++ b/internal/ghinstance/host_test.go @@ -0,0 +1,121 @@ +package ghinstance + +import ( + "testing" +) + +func TestIsEnterprise(t *testing.T) { + tests := []struct { + host string + want bool + }{ + { + host: "github.com", + want: false, + }, + { + host: "api.github.com", + want: false, + }, + { + host: "ghe.io", + want: true, + }, + { + host: "example.com", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := IsEnterprise(tt.host); got != tt.want { + t.Errorf("IsEnterprise() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNormalizeHostname(t *testing.T) { + tests := []struct { + host string + want string + }{ + { + host: "GitHub.com", + want: "github.com", + }, + { + host: "api.github.com", + want: "github.com", + }, + { + host: "ssh.github.com", + want: "github.com", + }, + { + host: "upload.github.com", + want: "github.com", + }, + { + host: "GHE.IO", + want: "ghe.io", + }, + { + host: "git.my.org", + want: "git.my.org", + }, + } + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := NormalizeHostname(tt.host); got != tt.want { + t.Errorf("NormalizeHostname() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGraphQLEndpoint(t *testing.T) { + tests := []struct { + host string + want string + }{ + { + host: "github.com", + want: "https://api.github.com/graphql", + }, + { + host: "ghe.io", + want: "https://ghe.io/api/graphql", + }, + } + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := GraphQLEndpoint(tt.host); got != tt.want { + t.Errorf("GraphQLEndpoint() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRESTPrefix(t *testing.T) { + tests := []struct { + host string + want string + }{ + { + host: "github.com", + want: "https://api.github.com/", + }, + { + host: "ghe.io", + want: "https://ghe.io/api/v3/", + }, + } + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := RESTPrefix(tt.host); got != tt.want { + t.Errorf("RESTPrefix() = %v, want %v", got, tt.want) + } + }) + } +} From d26cd64745917f7ebb5a73c3719c657aeaa0eafa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 29 Jul 2020 16:29:39 +0200 Subject: [PATCH 5/6] Support GraphQL `operationName` in `gh api` command GraphQL supports supplying multiple queries in the `query` parameter, but an additional `operationName` parameter is then required to select the query to execute. Previously, it was impossible to pass `operationName` since it would get serialized under `variables`, but it needs to be a top-level parameter. With this change, `operationName` is a special GraphQL parameter name just like `query` already is. --- pkg/cmd/api/api.go | 3 +++ pkg/cmd/api/http.go | 2 +- pkg/cmd/api/http_test.go | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 86914693d..8900f3f39 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -74,6 +74,9 @@ on the format of the value: - if the value starts with "@", the rest of the value is interpreted as a filename to read the value from. Pass "-" to read from standard input. +For GraphQL requests, all fields other than "query" and "operationName" are +interpreted as GraphQL variables. + Raw request body may be passed from the outside via a file specified by '--input'. Pass "-" to read from standard input. In this mode, parameters specified via '--field' flags are serialized into URL query parameters. diff --git a/pkg/cmd/api/http.go b/pkg/cmd/api/http.go index e393bb593..3efeed6e4 100644 --- a/pkg/cmd/api/http.go +++ b/pkg/cmd/api/http.go @@ -87,7 +87,7 @@ func groupGraphQLVariables(params map[string]interface{}) map[string]interface{} for key, val := range params { switch key { - case "query": + case "query", "operationName": topLevel[key] = val default: variables[key] = val diff --git a/pkg/cmd/api/http_test.go b/pkg/cmd/api/http_test.go index 0f9471111..f0768b026 100644 --- a/pkg/cmd/api/http_test.go +++ b/pkg/cmd/api/http_test.go @@ -55,6 +55,21 @@ func Test_groupGraphQLVariables(t *testing.T) { }, }, }, + { + name: "query + operationName + variables", + args: map[string]interface{}{ + "query": "query Q1{} query Q2{}", + "operationName": "Q1", + "power": 9001, + }, + want: map[string]interface{}{ + "query": "query Q1{} query Q2{}", + "operationName": "Q1", + "variables": map[string]interface{}{ + "power": 9001, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From bb6851c88b2393e7f69c2a7637f4f81897c8db4e Mon Sep 17 00:00:00 2001 From: Shubhankar Kanchan Gupta Date: Fri, 31 Jul 2020 21:52:08 +0530 Subject: [PATCH 6/6] Add "open" qualifier when listing open issues/PRs (#1457) --- command/issue.go | 2 +- command/issue_test.go | 8 ++++---- command/pr_test.go | 2 +- utils/utils.go | 3 +-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/command/issue.go b/command/issue.go index 43049faa2..3e481ef8e 100644 --- a/command/issue.go +++ b/command/issue.go @@ -376,7 +376,7 @@ func listHeader(repoName string, itemName string, matchCount int, totalMatchCoun return fmt.Sprintf("Showing %d of %s in %s that %s your search", matchCount, utils.Pluralize(totalMatchCount, itemName), repoName, matchVerb) } - return fmt.Sprintf("Showing %d of %s in %s", matchCount, utils.Pluralize(totalMatchCount, itemName), repoName) + return fmt.Sprintf("Showing %d of %s in %s", matchCount, utils.Pluralize(totalMatchCount, fmt.Sprintf("open %s", itemName)), repoName) } func printRawIssuePreview(out io.Writer, issue *api.Issue) error { diff --git a/command/issue_test.go b/command/issue_test.go index 7e7532a1e..8cde92d6e 100644 --- a/command/issue_test.go +++ b/command/issue_test.go @@ -145,7 +145,7 @@ func TestIssueList_tty(t *testing.T) { } eq(t, output.Stderr(), ` -Showing 3 of 3 issues in OWNER/REPO +Showing 3 of 3 open issues in OWNER/REPO `) @@ -847,7 +847,7 @@ func Test_listHeader(t *testing.T) { totalMatchCount: 23, hasFilters: false, }, - want: "Showing 1 of 23 genies in REPO", + want: "Showing 1 of 23 open genies in REPO", }, { name: "one result after filters", @@ -869,7 +869,7 @@ func Test_listHeader(t *testing.T) { totalMatchCount: 1, hasFilters: false, }, - want: "Showing 1 of 1 chip in REPO", + want: "Showing 1 of 1 open chip in REPO", }, { name: "one result in total after filters", @@ -891,7 +891,7 @@ func Test_listHeader(t *testing.T) { totalMatchCount: 23, hasFilters: false, }, - want: "Showing 4 of 23 plants in REPO", + want: "Showing 4 of 23 open plants in REPO", }, { name: "multiple results after filters", diff --git a/command/pr_test.go b/command/pr_test.go index fdf314022..f11939b1b 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -307,7 +307,7 @@ func TestPRList(t *testing.T) { } assert.Equal(t, ` -Showing 3 of 3 pull requests in OWNER/REPO +Showing 3 of 3 open pull requests in OWNER/REPO `, output.Stderr()) diff --git a/utils/utils.go b/utils/utils.go index 5751f55dd..20a5e55a8 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -48,9 +48,8 @@ func RenderMarkdown(text string) (string, error) { func Pluralize(num int, thing string) string { if num == 1 { return fmt.Sprintf("%d %s", num, thing) - } else { - return fmt.Sprintf("%d %ss", num, thing) } + return fmt.Sprintf("%d %ss", num, thing) } func fmtDuration(amount int, unit string) string {