diff --git a/api/cache.go b/api/cache.go new file mode 100644 index 000000000..620660c15 --- /dev/null +++ b/api/cache.go @@ -0,0 +1,175 @@ +package api + +import ( + "bufio" + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client { + cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") + return &http.Client{ + Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport), + } +} + +func isCacheableRequest(req *http.Request) bool { + if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") { + return true + } + + if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") { + return true + } + + return false +} + +func isCacheableResponse(res *http.Response) bool { + return res.StatusCode < 500 && res.StatusCode != 403 +} + +// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time +func CacheReponse(ttl time.Duration, dir string) ClientOption { + fs := fileStorage{ + dir: dir, + ttl: ttl, + mu: &sync.RWMutex{}, + } + + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + if !isCacheableRequest(req) { + return tr.RoundTrip(req) + } + + key, keyErr := cacheKey(req) + if keyErr == nil { + if res, err := fs.read(key); err == nil { + res.Request = req + return res, nil + } + } + + res, err := tr.RoundTrip(req) + if err == nil && keyErr == nil && isCacheableResponse(res) { + _ = fs.store(key, res) + } + return res, err + }} + } +} + +func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) { + b := &bytes.Buffer{} + nr := io.TeeReader(r, b) + return ioutil.NopCloser(b), &readCloser{ + Reader: nr, + Closer: r, + } +} + +type readCloser struct { + io.Reader + io.Closer +} + +func cacheKey(req *http.Request) (string, error) { + h := sha256.New() + fmt.Fprintf(h, "%s:", req.Method) + fmt.Fprintf(h, "%s:", req.URL.String()) + fmt.Fprintf(h, "%s:", req.Header.Get("Accept")) + fmt.Fprintf(h, "%s:", req.Header.Get("Authorization")) + + if req.Body != nil { + var bodyCopy io.ReadCloser + req.Body, bodyCopy = copyStream(req.Body) + defer bodyCopy.Close() + if _, err := io.Copy(h, bodyCopy); err != nil { + return "", err + } + } + + digest := h.Sum(nil) + return fmt.Sprintf("%x", digest), nil +} + +type fileStorage struct { + dir string + ttl time.Duration + mu *sync.RWMutex +} + +func (fs *fileStorage) filePath(key string) string { + if len(key) >= 6 { + return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:]) + } + return filepath.Join(fs.dir, key) +} + +func (fs *fileStorage) read(key string) (*http.Response, error) { + cacheFile := fs.filePath(key) + + fs.mu.RLock() + defer fs.mu.RUnlock() + + f, err := os.Open(cacheFile) + if err != nil { + return nil, err + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return nil, err + } + + age := time.Since(stat.ModTime()) + if age > fs.ttl { + return nil, errors.New("cache expired") + } + + body := &bytes.Buffer{} + _, err = io.Copy(body, f) + if err != nil { + return nil, err + } + + res, err := http.ReadResponse(bufio.NewReader(body), nil) + return res, err +} + +func (fs *fileStorage) store(key string, res *http.Response) error { + cacheFile := fs.filePath(key) + + fs.mu.Lock() + defer fs.mu.Unlock() + + err := os.MkdirAll(filepath.Dir(cacheFile), 0755) + if err != nil { + return err + } + + f, err := os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer f.Close() + + var origBody io.ReadCloser + origBody, res.Body = copyStream(res.Body) + defer res.Body.Close() + err = res.Write(f) + res.Body = origBody + return err +} diff --git a/api/cache_test.go b/api/cache_test.go new file mode 100644 index 000000000..d1039d71b --- /dev/null +++ b/api/cache_test.go @@ -0,0 +1,89 @@ +package api + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CacheReponse(t *testing.T) { + counter := 0 + fakeHTTP := funcTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + counter += 1 + body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) + status := 200 + if req.URL.Path == "/error" { + status = 500 + } + return &http.Response{ + StatusCode: status, + Body: ioutil.NopCloser(bytes.NewBufferString(body)), + }, nil + }, + } + + cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") + httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir)) + + do := func(method, url string, body io.Reader) (string, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return "", err + } + res, err := httpClient.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + resBody, err := ioutil.ReadAll(res.Body) + if err != nil { + err = fmt.Errorf("ReadAll: %w", err) + } + return string(resBody), err + } + + var res string + var err error + + res, err = do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res) + res, err = do("GET", "http://example.com/path", nil) + require.NoError(t, err) + assert.Equal(t, "1: GET http://example.com/path", res) + + res, err = do("GET", "http://example.com/path2", nil) + require.NoError(t, err) + assert.Equal(t, "2: GET http://example.com/path2", res) + + res, err = do("POST", "http://example.com/path2", nil) + require.NoError(t, err) + assert.Equal(t, "3: POST http://example.com/path2", res) + + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/graphql", res) + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) + require.NoError(t, err) + assert.Equal(t, "4: POST http://example.com/graphql", res) + + res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) + require.NoError(t, err) + assert.Equal(t, "5: POST http://example.com/graphql", res) + + res, err = do("GET", "http://example.com/error", nil) + require.NoError(t, err) + assert.Equal(t, "6: GET http://example.com/error", res) + res, err = do("GET", "http://example.com/error", nil) + require.NoError(t, err) + assert.Equal(t, "7: GET http://example.com/error", res) +} diff --git a/api/queries_pr.go b/api/queries_pr.go index 4e93cb293..34011d5be 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -241,6 +241,52 @@ func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (io.Rea return resp.Body, nil } +type pullRequestFeature struct { + HasReviewDecision bool + HasStatusCheckRollup bool +} + +func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) { + if !ghinstance.IsEnterprise(hostname) { + prFeatures.HasReviewDecision = true + prFeatures.HasStatusCheckRollup = true + return + } + + var featureDetection struct { + PullRequest struct { + Fields []struct { + Name string + } `graphql:"fields(includeDeprecated: true)"` + } `graphql:"PullRequest: __type(name: \"PullRequest\")"` + Commit struct { + Fields []struct { + Name string + } `graphql:"fields(includeDeprecated: true)"` + } `graphql:"Commit: __type(name: \"Commit\")"` + } + + v4 := graphQLClient(httpClient, hostname) + err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil) + if err != nil { + return + } + + for _, field := range featureDetection.PullRequest.Fields { + switch field.Name { + case "reviewDecision": + prFeatures.HasReviewDecision = true + } + } + for _, field := range featureDetection.Commit.Fields { + switch field.Name { + case "statusCheckRollup": + prFeatures.HasStatusCheckRollup = true + } + } + return +} + func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, currentPRHeadRef, currentUsername string) (*PullRequestsPayload, error) { type edges struct { TotalCount int @@ -261,18 +307,20 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu ReviewRequested edges } - fragments := ` - fragment pr on PullRequest { - number - title - state - url - headRefName - headRepositoryOwner { - login - } - isCrossRepository - isDraft + cachedClient := makeCachedClient(client.http, time.Hour*24) + prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost()) + if err != nil { + return nil, err + } + + var reviewsFragment string + if prFeatures.HasReviewDecision { + reviewsFragment = "reviewDecision" + } + + var statusesFragment string + if prFeatures.HasStatusCheckRollup { + statusesFragment = ` commits(last: 1) { nodes { commit { @@ -292,12 +340,28 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } } } + ` + } + + fragments := fmt.Sprintf(` + fragment pr on PullRequest { + number + title + state + url + headRefName + headRepositoryOwner { + login + } + isCrossRepository + isDraft + %s } fragment prWithReviews on PullRequest { ...pr - reviewDecision + %s } - ` + `, statusesFragment, reviewsFragment) queryPrefix := ` query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) { @@ -345,6 +409,13 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } ` + if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) { + currentUsername, err = CurrentLoginName(client, repo.RepoHost()) + if err != nil { + return nil, err + } + } + viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername) reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername) @@ -363,7 +434,7 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu } var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err = client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -404,6 +475,45 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu return &payload, nil } +func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) { + cachedClient := makeCachedClient(httpClient, time.Hour*24) + if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil { + return "", err + } else if !prFeatures.HasStatusCheckRollup { + return "", nil + } + + return ` + commits(last: 1) { + totalCount + nodes { + commit { + oid + statusCheckRollup { + contexts(last: 100) { + nodes { + ...on StatusContext { + context + state + targetUrl + } + ...on CheckRun { + name + status + conclusion + startedAt + completedAt + detailsUrl + } + } + } + } + } + } + } + `, nil +} + func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*PullRequest, error) { type response struct { Repository struct { @@ -411,6 +521,11 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu } } + statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost()) + if err != nil { + return nil, err + } + query := ` query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { repository(owner: $owner, name: $repo) { @@ -426,33 +541,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu author { login } - commits(last: 1) { - totalCount - nodes { - commit { - oid - statusCheckRollup { - contexts(last: 100) { - nodes { - ...on StatusContext { - context - state - targetUrl - } - ...on CheckRun { - name - status - conclusion - startedAt - completedAt - detailsUrl - } - } - } - } - } - } - } + ` + statusesFragment + ` baseRefName headRefName headRepositoryOwner { @@ -524,7 +613,7 @@ func PullRequestByNumber(client *Client, repo ghrepo.Interface, number int) (*Pu } var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err = client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -542,6 +631,11 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea } } + statusesFragment, err := prCommitsFragment(client.http, repo.RepoHost()) + if err != nil { + return nil, err + } + query := ` query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!) { repository(owner: $owner, name: $repo) { @@ -556,33 +650,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea author { login } - commits(last: 1) { - totalCount - nodes { - commit { - oid - statusCheckRollup { - contexts(last: 100) { - nodes { - ...on StatusContext { - context - state - targetUrl - } - ...on CheckRun { - name - status - conclusion - startedAt - completedAt - detailsUrl - } - } - } - } - } - } - } + ` + statusesFragment + ` url baseRefName headRefName @@ -661,7 +729,7 @@ func PullRequestForBranch(client *Client, repo ghrepo.Interface, baseBranch, hea } var resp response - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err = client.GraphQL(repo.RepoHost(), query, variables, &resp) if err != nil { return nil, err } diff --git a/api/queries_pr_test.go b/api/queries_pr_test.go index 3441534bf..e624b61b3 100644 --- a/api/queries_pr_test.go +++ b/api/queries_pr_test.go @@ -1,8 +1,10 @@ package api import ( + "reflect" "testing" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/httpmock" ) @@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) { }) } } + +func Test_determinePullRequestFeatures(t *testing.T) { + tests := []struct { + name string + hostname string + queryResponse string + wantPrFeatures pullRequestFeature + wantErr bool + }{ + { + name: "github.com", + hostname: "github.com", + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: true, + HasStatusCheckRollup: true, + }, + wantErr: false, + }, + { + name: "GHE empty response", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": {}} + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: false, + HasStatusCheckRollup: false, + }, + wantErr: false, + }, + { + name: "GHE has reviewDecision", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": { + "PullRequest": { + "fields": [ + {"name": "foo"}, + {"name": "reviewDecision"} + ] + } + } } + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: true, + HasStatusCheckRollup: false, + }, + wantErr: false, + }, + { + name: "GHE has statusCheckRollup", + hostname: "git.my.org", + queryResponse: heredoc.Doc(` + {"data": { + "Commit": { + "fields": [ + {"name": "foo"}, + {"name": "statusCheckRollup"} + ] + } + } } + `), + wantPrFeatures: pullRequestFeature{ + HasReviewDecision: false, + HasStatusCheckRollup: true, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeHTTP := &httpmock.Registry{} + httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP)) + + if tt.queryResponse != "" { + fakeHTTP.Register( + httpmock.GraphQL(`query PullRequest_fields\b`), + httpmock.StringResponse(tt.queryResponse)) + } + + gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname) + if (err != nil) != tt.wantErr { + t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) { + t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures) + } + }) + } +}