Merge 'trunk' into fix/issue_10042

This commit is contained in:
Kynan Ware 2025-01-16 11:03:17 -07:00
commit a2dce589d6
24 changed files with 1335 additions and 247 deletions

View file

@ -377,19 +377,36 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte,
}
// ReadBranchConfig parses the `branch.BRANCH.(remote|merge|gh-merge-base)` part of git config.
func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg BranchConfig) {
// If no branch config is found or there is an error in the command, it returns an empty BranchConfig.
// Downstream consumers of ReadBranchConfig should consider the behavior they desire if this errors,
// as an empty config is not necessarily breaking.
func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchConfig, error) {
prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch))
args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|%s)$", prefix, MergeBaseConfig)}
cmd, err := c.Command(ctx, args...)
if err != nil {
return
}
out, err := cmd.Output()
if err != nil {
return
return BranchConfig{}, err
}
for _, line := range outputLines(out) {
out, err := cmd.Output()
if err != nil {
// This is the error we expect if the git command does not run successfully.
// If the ExitCode is 1, then we just didn't find any config for the branch.
var gitError *GitError
if ok := errors.As(err, &gitError); ok && gitError.ExitCode != 1 {
return BranchConfig{}, err
}
return BranchConfig{}, nil
}
return parseBranchConfig(outputLines(out)), nil
}
func parseBranchConfig(configLines []string) BranchConfig {
var cfg BranchConfig
for _, line := range configLines {
parts := strings.SplitN(line, " ", 2)
if len(parts) < 2 {
continue
@ -412,7 +429,7 @@ func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (cfg Branc
cfg.MergeBase = parts[1]
}
}
return
return cfg
}
// SetBranchConfig sets the named value on the given branch.

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"os/exec"
"path/filepath"
@ -730,14 +731,35 @@ func TestClientReadBranchConfig(t *testing.T) {
cmdExitStatus int
cmdStdout string
cmdStderr string
wantCmdArgs string
branch string
wantBranchConfig BranchConfig
wantError *GitError
}{
{
name: "read branch config",
cmdExitStatus: 0,
cmdStdout: "branch.trunk.remote origin\nbranch.trunk.merge refs/heads/trunk\nbranch.trunk.gh-merge-base trunk",
wantCmdArgs: `path/to/git config --get-regexp ^branch\.trunk\.(remote|merge|gh-merge-base)$`,
branch: "trunk",
wantBranchConfig: BranchConfig{RemoteName: "origin", MergeRef: "refs/heads/trunk", MergeBase: "trunk"},
wantError: nil,
},
{
name: "git config runs successfully but returns no output (Exit Code 1)",
cmdExitStatus: 1,
cmdStdout: "",
cmdStderr: "",
branch: "trunk",
wantBranchConfig: BranchConfig{},
wantError: nil,
},
{
name: "output error (Exit Code > 1)",
cmdExitStatus: 2,
cmdStdout: "",
cmdStderr: "git error message",
branch: "trunk",
wantBranchConfig: BranchConfig{},
wantError: &GitError{},
},
}
for _, tt := range tests {
@ -747,9 +769,85 @@ func TestClientReadBranchConfig(t *testing.T) {
GitPath: "path/to/git",
commandContext: cmdCtx,
}
branchConfig := client.ReadBranchConfig(context.Background(), "trunk")
assert.Equal(t, tt.wantCmdArgs, strings.Join(cmd.Args[3:], " "))
branchConfig, err := client.ReadBranchConfig(context.Background(), tt.branch)
wantCmdArgs := fmt.Sprintf("path/to/git config --get-regexp ^branch\\.%s\\.(remote|merge|gh-merge-base)$", tt.branch)
assert.Equal(t, wantCmdArgs, strings.Join(cmd.Args[3:], " "))
assert.Equal(t, tt.wantBranchConfig, branchConfig)
if tt.wantError != nil {
assert.ErrorAs(t, err, &tt.wantError)
} else {
assert.NoError(t, err)
}
})
}
}
func Test_parseBranchConfig(t *testing.T) {
tests := []struct {
name string
configLines []string
wantBranchConfig BranchConfig
}{
{
name: "remote branch",
configLines: []string{"branch.trunk.remote origin"},
wantBranchConfig: BranchConfig{
RemoteName: "origin",
},
},
{
name: "merge ref",
configLines: []string{"branch.trunk.merge refs/heads/trunk"},
wantBranchConfig: BranchConfig{
MergeRef: "refs/heads/trunk",
},
},
{
name: "merge base",
configLines: []string{"branch.trunk.gh-merge-base gh-merge-base"},
wantBranchConfig: BranchConfig{
MergeBase: "gh-merge-base",
},
},
{
name: "remote, merge ref, and merge base all specified",
configLines: []string{
"branch.trunk.remote origin",
"branch.trunk.merge refs/heads/trunk",
"branch.trunk.gh-merge-base gh-merge-base",
},
wantBranchConfig: BranchConfig{
RemoteName: "origin",
MergeRef: "refs/heads/trunk",
MergeBase: "gh-merge-base",
},
},
{
name: "remote URL",
configLines: []string{
"branch.Frederick888/main.remote git@github.com:Frederick888/playground.git",
"branch.Frederick888/main.merge refs/heads/main",
},
wantBranchConfig: BranchConfig{
MergeRef: "refs/heads/main",
RemoteURL: &url.URL{
Scheme: "ssh",
User: url.User("git"),
Host: "github.com",
Path: "/Frederick888/playground.git",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
branchConfig := parseBranchConfig(tt.configLines)
assert.Equal(t, tt.wantBranchConfig.RemoteName, branchConfig.RemoteName)
assert.Equal(t, tt.wantBranchConfig.MergeRef, branchConfig.MergeRef)
assert.Equal(t, tt.wantBranchConfig.MergeBase, branchConfig.MergeBase)
if tt.wantBranchConfig.RemoteURL != nil {
assert.Equal(t, tt.wantBranchConfig.RemoteURL.String(), branchConfig.RemoteURL.String())
}
})
}
}

5
go.mod
View file

@ -21,6 +21,7 @@ require (
github.com/distribution/reference v0.5.0
github.com/gabriel-vasile/mimetype v1.4.7
github.com/gdamore/tcell/v2 v2.5.4
github.com/golang/snappy v0.0.4
github.com/google/go-cmp v0.6.0
github.com/google/go-containerregistry v0.20.2
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
@ -39,7 +40,7 @@ require (
github.com/opentracing/opentracing-go v1.2.0
github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d
github.com/shurcooL/githubv4 v0.0.0-20240120211514-18a1ae0e79dc
github.com/sigstore/protobuf-specs v0.3.2
github.com/sigstore/protobuf-specs v0.3.3
github.com/sigstore/sigstore-go v0.6.2
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
@ -50,7 +51,7 @@ require (
golang.org/x/term v0.27.0
golang.org/x/text v0.21.0
google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.2
google.golang.org/protobuf v1.36.2
gopkg.in/h2non/gock.v1 v1.1.2
gopkg.in/yaml.v3 v3.0.1
)

10
go.sum
View file

@ -198,6 +198,8 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/certificate-transparency-go v1.2.1 h1:4iW/NwzqOqYEEoCBEFP+jPbBXbLqMpq3CifMyOnDUME=
github.com/google/certificate-transparency-go v1.2.1/go.mod h1:bvn/ytAccv+I6+DGkqpvSsEdiVGramgaSC6RD3tEmeE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@ -391,8 +393,8 @@ github.com/shurcooL/githubv4 v0.0.0-20240120211514-18a1ae0e79dc h1:vH0NQbIDk+mJL
github.com/shurcooL/githubv4 v0.0.0-20240120211514-18a1ae0e79dc/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8=
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0=
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE=
github.com/sigstore/protobuf-specs v0.3.2 h1:nCVARCN+fHjlNCk3ThNXwrZRqIommIeNKWwQvORuRQo=
github.com/sigstore/protobuf-specs v0.3.2/go.mod h1:RZ0uOdJR4OB3tLQeAyWoJFbNCBFrPQdcokntde4zRBA=
github.com/sigstore/protobuf-specs v0.3.3 h1:RMZQgXTD/pF7KW6b5NaRLYxFYZ/wzx44PQFXN2PEo5g=
github.com/sigstore/protobuf-specs v0.3.3/go.mod h1:vIhZ6Uor1a38+wvRrKcqL2PtYNlgoIW9lhzYzkyy4EU=
github.com/sigstore/rekor v1.3.6 h1:QvpMMJVWAp69a3CHzdrLelqEqpTM3ByQRt5B5Kspbi8=
github.com/sigstore/rekor v1.3.6/go.mod h1:JDTSNNMdQ/PxdsS49DJkJ+pRJCO/83nbR5p3aZQteXc=
github.com/sigstore/sigstore v1.8.9 h1:NiUZIVWywgYuVTxXmRoTT4O4QAGiTEKup4N1wdxFadk=
@ -546,8 +548,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240520151616-dc85e6b867a5 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20240520151616-dc85e6b867a5/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
google.golang.org/protobuf v1.36.2 h1:R8FeyR1/eLmkutZOM5CWghmo5itiG9z0ktFlTVLuTmU=
google.golang.org/protobuf v1.36.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View file

@ -25,7 +25,8 @@ func newErrNoAttestations(name, digest string) ErrNoAttestations {
}
type Attestation struct {
Bundle *bundle.Bundle `json:"bundle"`
Bundle *bundle.Bundle `json:"bundle"`
BundleURL string `json:"bundle_url"`
}
type AttestationsResponse struct {

View file

@ -11,6 +11,11 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/api"
ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io"
"github.com/golang/snappy"
v1 "github.com/sigstore/protobuf-specs/gen/pb-go/bundle/v1"
"github.com/sigstore/sigstore-go/pkg/bundle"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/encoding/protojson"
)
const (
@ -19,11 +24,20 @@ const (
maxLimitForFetch = 100
)
type apiClient interface {
// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200
// githubApiClient makes REST calls to the GitHub API
type githubApiClient interface {
REST(hostname, method, p string, body io.Reader, data interface{}) error
RESTWithNext(hostname, method, p string, body io.Reader, data interface{}) (string, error)
}
// httpClient makes HTTP calls to all non-GitHub API endpoints
type httpClient interface {
Get(url string) (*http.Response, error)
}
type Client interface {
GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error)
GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error)
@ -31,16 +45,18 @@ type Client interface {
}
type LiveClient struct {
api apiClient
host string
logger *ioconfig.Handler
githubAPI githubApiClient
httpClient httpClient
host string
logger *ioconfig.Handler
}
func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClient {
return &LiveClient{
api: api.NewClientFromHTTP(hc),
host: strings.TrimSuffix(host, "/"),
logger: l,
githubAPI: api.NewClientFromHTTP(hc),
host: strings.TrimSuffix(host, "/"),
httpClient: hc,
logger: l,
}
}
@ -52,7 +68,17 @@ func (c *LiveClient) BuildRepoAndDigestURL(repo, digest string) string {
// GetByRepoAndDigest fetches the attestation by repo and digest
func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) {
url := c.BuildRepoAndDigestURL(repo, digest)
return c.getAttestations(url, repo, digest, limit)
attestations, err := c.getAttestations(url, repo, digest, limit)
if err != nil {
return nil, err
}
bundles, err := c.fetchBundleFromAttestations(attestations)
if err != nil {
return nil, fmt.Errorf("failed to fetch bundle with URL: %w", err)
}
return bundles, nil
}
func (c *LiveClient) BuildOwnerAndDigestURL(owner, digest string) string {
@ -63,7 +89,21 @@ func (c *LiveClient) BuildOwnerAndDigestURL(owner, digest string) string {
// GetByOwnerAndDigest fetches attestation by owner and digest
func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) {
url := c.BuildOwnerAndDigestURL(owner, digest)
return c.getAttestations(url, owner, digest, limit)
attestations, err := c.getAttestations(url, owner, digest, limit)
if err != nil {
return nil, err
}
if len(attestations) == 0 {
return nil, newErrNoAttestations(owner, digest)
}
bundles, err := c.fetchBundleFromAttestations(attestations)
if err != nil {
return nil, fmt.Errorf("failed to fetch bundle with URL: %w", err)
}
return bundles, nil
}
// GetTrustDomain returns the current trust domain. If the default is used
@ -72,9 +112,6 @@ func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}
// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200
func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)
@ -97,7 +134,7 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
err := backoff.Retry(func() error {
newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
@ -130,6 +167,77 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
return attestations, nil
}
func (c *LiveClient) fetchBundleFromAttestations(attestations []*Attestation) ([]*Attestation, error) {
fetched := make([]*Attestation, len(attestations))
g := errgroup.Group{}
for i, a := range attestations {
g.Go(func() error {
if a.Bundle == nil && a.BundleURL == "" {
return fmt.Errorf("attestation has no bundle or bundle URL")
}
// for now, we fallback to the bundle field if the bundle URL is empty
if a.BundleURL == "" {
c.logger.VerbosePrintf("Bundle URL is empty. Falling back to bundle field\n\n")
fetched[i] = &Attestation{
Bundle: a.Bundle,
}
return nil
}
// otherwise fetch the bundle with the provided URL
b, err := c.GetBundle(a.BundleURL)
if err != nil {
return fmt.Errorf("failed to fetch bundle with URL: %w", err)
}
fetched[i] = &Attestation{
Bundle: b,
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return fetched, nil
}
func (c *LiveClient) GetBundle(url string) (*bundle.Bundle, error) {
c.logger.VerbosePrintf("Fetching attestation bundle with bundle URL\n\n")
resp, err := c.httpClient.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read blob storage response body: %w", err)
}
var out []byte
decompressed, err := snappy.Decode(out, body)
if err != nil {
return nil, fmt.Errorf("failed to decompress with snappy: %w", err)
}
var pbBundle v1.Bundle
if err = protojson.Unmarshal(decompressed, &pbBundle); err != nil {
return nil, fmt.Errorf("failed to unmarshal to bundle: %w", err)
}
c.logger.VerbosePrintf("Successfully fetched bundle\n\n")
return bundle.NewBundle(&pbBundle)
}
func shouldRetry(err error) bool {
var httpError api.HTTPError
if errors.As(err, &httpError) {
@ -146,7 +254,7 @@ func (c *LiveClient) getTrustDomain(url string) (string, error) {
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
err := backoff.Retry(func() error {
restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
restErr := c.githubAPI.REST(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/cli/cli/v2/pkg/cmd/attestation/io"
"github.com/cli/cli/v2/pkg/cmd/attestation/test/data"
"github.com/stretchr/testify/require"
)
@ -20,20 +21,24 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
}
l := io.NewTestHandler()
httpClient := &mockHttpClient{}
if hasNextPage {
return &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage,
},
logger: l,
httpClient: httpClient,
logger: l,
}
}
return &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccess,
},
logger: l,
httpClient: httpClient,
logger: l,
}
}
@ -134,11 +139,13 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) {
NumAttestations: 5,
}
httpClient := &mockHttpClient{}
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations,
},
logger: io.NewTestHandler(),
httpClient: httpClient,
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
@ -158,7 +165,7 @@ func TestGetByDigest_Error(t *testing.T) {
}
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextError,
},
logger: io.NewTestHandler(),
@ -173,6 +180,86 @@ func TestGetByDigest_Error(t *testing.T) {
require.Nil(t, attestations)
}
func TestFetchBundleFromAttestations(t *testing.T) {
httpClient := &mockHttpClient{}
client := LiveClient{
httpClient: httpClient,
logger: io.NewTestHandler(),
}
att1 := makeTestAttestation()
att2 := makeTestAttestation()
attestations := []*Attestation{&att1, &att2}
fetched, err := client.fetchBundleFromAttestations(attestations)
require.NoError(t, err)
require.Len(t, fetched, 2)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType())
httpClient.AssertNumberOfCalls(t, "OnGetSuccess", 2)
}
func TestFetchBundleFromAttestations_InvalidAttestation(t *testing.T) {
httpClient := &mockHttpClient{}
client := LiveClient{
httpClient: httpClient,
logger: io.NewTestHandler(),
}
att1 := Attestation{}
attestations := []*Attestation{&att1}
fetched, err := client.fetchBundleFromAttestations(attestations)
require.Error(t, err)
require.Nil(t, fetched, 2)
}
func TestFetchBundleFromAttestations_Fail(t *testing.T) {
httpClient := &failAfterOneCallHttpClient{}
c := &LiveClient{
httpClient: httpClient,
logger: io.NewTestHandler(),
}
att1 := makeTestAttestation()
att2 := makeTestAttestation()
attestations := []*Attestation{&att1, &att2}
fetched, err := c.fetchBundleFromAttestations(attestations)
require.Error(t, err)
require.Nil(t, fetched)
httpClient.AssertNumberOfCalls(t, "OnGetFailAfterOneCall", 2)
}
func TestFetchBundleFromAttestations_FetchByURLFail(t *testing.T) {
mockHTTPClient := &failHttpClient{}
c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
a := makeTestAttestation()
attestations := []*Attestation{&a}
bundle, err := c.fetchBundleFromAttestations(attestations)
require.Error(t, err)
require.Nil(t, bundle)
mockHTTPClient.AssertNumberOfCalls(t, "OnGetFail", 1)
}
func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
mockHTTPClient := &mockHttpClient{}
c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
a := Attestation{Bundle: data.SigstoreBundle(t)}
attestations := []*Attestation{&a}
fetched, err := c.fetchBundleFromAttestations(attestations)
require.NoError(t, err)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType())
mockHTTPClient.AssertNotCalled(t, "OnGetSuccess")
}
func TestGetTrustDomain(t *testing.T) {
fetcher := mockMetaGenerator{
TrustDomain: "foo",
@ -180,7 +267,7 @@ func TestGetTrustDomain(t *testing.T) {
t.Run("with returned trust domain", func(t *testing.T) {
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnREST: fetcher.OnREST,
},
logger: io.NewTestHandler(),
@ -193,7 +280,7 @@ func TestGetTrustDomain(t *testing.T) {
t.Run("with error", func(t *testing.T) {
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnREST: fetcher.OnRESTError,
},
logger: io.NewTestHandler(),
@ -213,10 +300,11 @@ func TestGetAttestationsRetries(t *testing.T) {
}
c := &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
},
logger: io.NewTestHandler(),
httpClient: &mockHttpClient{},
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
@ -252,7 +340,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) {
}
c := &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnREST500ErrorHandler(),
},
logger: io.NewTestHandler(),

View file

@ -25,7 +25,7 @@ func (m MockClient) GetTrustDomain() (string, error) {
}
func makeTestAttestation() Attestation {
return Attestation{Bundle: data.SigstoreBundle(nil)}
return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"}
}
func OnGetByRepoAndDigestSuccess(repo, digest string, limit int) ([]*Attestation, error) {

View file

@ -0,0 +1,64 @@
package api
import (
"bytes"
"fmt"
"io"
"net/http"
"github.com/cli/cli/v2/pkg/cmd/attestation/test/data"
"github.com/golang/snappy"
"github.com/stretchr/testify/mock"
)
type mockHttpClient struct {
mock.Mock
}
func (m *mockHttpClient) Get(url string) (*http.Response, error) {
m.On("OnGetSuccess").Return()
m.MethodCalled("OnGetSuccess")
var compressed []byte
compressed = snappy.Encode(compressed, data.SigstoreBundleRaw)
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(compressed)),
}, nil
}
type failHttpClient struct {
mock.Mock
}
func (m *failHttpClient) Get(url string) (*http.Response, error) {
m.On("OnGetFail").Return()
m.MethodCalled("OnGetFail")
return &http.Response{
StatusCode: 500,
}, fmt.Errorf("failed to fetch with %s", url)
}
type failAfterOneCallHttpClient struct {
mock.Mock
}
func (m *failAfterOneCallHttpClient) Get(url string) (*http.Response, error) {
m.On("OnGetFailAfterOneCall").Return()
if len(m.Calls) >= 1 {
m.MethodCalled("OnGetFailAfterOneCall")
return &http.Response{
StatusCode: 500,
}, fmt.Errorf("failed to fetch with %s", url)
}
m.MethodCalled("OnGetFailAfterOneCall")
var compressed []byte
compressed = snappy.Encode(compressed, data.SigstoreBundleRaw)
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(compressed)),
}, nil
}

View file

@ -6,8 +6,10 @@ import (
"net/http"
"strings"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmd/gist/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
@ -18,8 +20,10 @@ type DeleteOptions struct {
IO *iostreams.IOStreams
Config func() (gh.Config, error)
HttpClient func() (*http.Client, error)
Prompter prompter.Prompter
Selector string
Selector string
Confirmed bool
}
func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Command {
@ -27,33 +31,51 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
IO: f.IOStreams,
Config: f.Config,
HttpClient: f.HttpClient,
Prompter: f.Prompter,
}
cmd := &cobra.Command{
Use: "delete {<id> | <url>}",
Short: "Delete a gist",
Args: cmdutil.ExactArgs(1, "cannot delete: gist argument required"),
Long: heredoc.Docf(`
Delete a GitHub gist.
To delete a gist interactively, use %[1]sgh gist delete%[1]s with no arguments.
To delete a gist non-interactively, supply the gist id or url.
`, "`"),
Example: heredoc.Doc(`
# delete a gist interactively
gh gist delete
# delete a gist non-interactively
gh gist delete 1234
`),
Args: cobra.MaximumNArgs(1),
RunE: func(c *cobra.Command, args []string) error {
opts.Selector = args[0]
if !opts.IO.CanPrompt() && !opts.Confirmed {
return cmdutil.FlagErrorf("--yes required when not running interactively")
}
if !opts.IO.CanPrompt() && len(args) == 0 {
return cmdutil.FlagErrorf("id or url argument required in non-interactive mode")
}
if len(args) == 1 {
opts.Selector = args[0]
}
if runF != nil {
return runF(&opts)
}
return deleteRun(&opts)
},
}
cmd.Flags().BoolVar(&opts.Confirmed, "yes", false, "confirm deletion without prompting")
return cmd
}
func deleteRun(opts *DeleteOptions) error {
gistID := opts.Selector
if strings.Contains(gistID, "/") {
id, err := shared.GistIDFromURL(gistID)
if err != nil {
return err
}
gistID = id
}
client, err := opts.HttpClient()
if err != nil {
return err
@ -66,14 +88,56 @@ func deleteRun(opts *DeleteOptions) error {
host, _ := cfg.Authentication().DefaultHost()
gistID := opts.Selector
if strings.Contains(gistID, "/") {
id, err := shared.GistIDFromURL(gistID)
if err != nil {
return err
}
gistID = id
}
cs := opts.IO.ColorScheme()
var gist *shared.Gist
if gistID == "" {
gist, err = shared.PromptGists(opts.Prompter, client, host, cs)
} else {
gist, err = shared.GetGist(client, host, gistID)
}
if err != nil {
return err
}
if gist.ID == "" {
fmt.Fprintln(opts.IO.Out, "No gists found.")
return nil
}
if !opts.Confirmed {
confirmed, err := opts.Prompter.Confirm(fmt.Sprintf("Delete %q gist?", gist.Filename()), false)
if err != nil {
return err
}
if !confirmed {
return cmdutil.CancelError
}
}
apiClient := api.NewClientFromHTTP(client)
if err := deleteGist(apiClient, host, gistID); err != nil {
if err := deleteGist(apiClient, host, gist.ID); err != nil {
if errors.Is(err, shared.NotFoundErr) {
return fmt.Errorf("unable to delete gist %s: either the gist is not found or it is not owned by you", gistID)
return fmt.Errorf("unable to delete gist %q: either the gist is not found or it is not owned by you", gist.Filename())
}
return err
}
if opts.IO.IsStdoutTTY() {
cs := opts.IO.ColorScheme()
fmt.Fprintf(opts.IO.Out,
"%s Gist %q deleted\n",
cs.SuccessIcon(),
gist.Filename())
}
return nil
}

View file

@ -2,36 +2,95 @@ package delete
import (
"bytes"
"fmt"
"net/http"
"testing"
"time"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmd/gist/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/cli/cli/v2/pkg/iostreams"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
)
func TestNewCmdDelete(t *testing.T) {
tests := []struct {
name string
cli string
wants DeleteOptions
name string
cli string
tty bool
want DeleteOptions
wantErr bool
wantErrMsg string
}{
{
name: "valid selector",
cli: "123",
wants: DeleteOptions{
tty: true,
want: DeleteOptions{
Selector: "123",
},
},
{
name: "valid selector, no ID supplied",
cli: "",
tty: true,
want: DeleteOptions{
Selector: "",
},
},
{
name: "no ID supplied with --yes",
cli: "--yes",
tty: true,
want: DeleteOptions{
Selector: "",
},
},
{
name: "selector with --yes, no tty",
cli: "123 --yes",
tty: false,
want: DeleteOptions{
Selector: "123",
},
},
{
name: "ID arg without --yes, no tty",
cli: "123",
tty: false,
want: DeleteOptions{
Selector: "",
},
wantErr: true,
wantErrMsg: "--yes required when not running interactively",
},
{
name: "no ID supplied with --yes, no tty",
cli: "--yes",
tty: false,
want: DeleteOptions{
Selector: "",
},
wantErr: true,
wantErrMsg: "id or url argument required in non-interactive mode",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &cmdutil.Factory{}
io, _, _, _ := iostreams.Test()
f := &cmdutil.Factory{
IOStreams: io,
}
io.SetStdinTTY(tt.tty)
io.SetStdoutTTY(tt.tty)
argv, err := shlex.Split(tt.cli)
assert.NoError(t, err)
@ -47,69 +106,210 @@ func TestNewCmdDelete(t *testing.T) {
cmd.SetErr(&bytes.Buffer{})
_, err = cmd.ExecuteC()
assert.NoError(t, err)
if tt.wantErr {
assert.EqualError(t, err, tt.wantErrMsg)
return
}
assert.Equal(t, tt.wants.Selector, gotOpts.Selector)
assert.NoError(t, err)
assert.Equal(t, tt.want.Selector, gotOpts.Selector)
})
}
}
func Test_deleteRun(t *testing.T) {
tests := []struct {
name string
opts DeleteOptions
httpStubs func(*httpmock.Registry)
wantErr bool
wantStdout string
wantStderr string
name string
opts *DeleteOptions
cancel bool
httpStubs func(*httpmock.Registry)
mockPromptGists bool
noGists bool
wantErr bool
wantStdout string
wantStderr string
}{
{
name: "successfully delete",
opts: DeleteOptions{
opts: &DeleteOptions{
Selector: "1234",
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "gists/1234"),
httpmock.JSONResponse(shared.Gist{ID: "1234", Files: map[string]*shared.GistFile{"cool.txt": {Filename: "cool.txt"}}}))
reg.Register(httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(200, "{}"))
},
wantStdout: "✓ Gist \"cool.txt\" deleted\n",
},
{
name: "successfully delete with prompt",
opts: &DeleteOptions{
Selector: "",
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(200, "{}"))
},
wantErr: false,
wantStdout: "",
wantStderr: "",
mockPromptGists: true,
wantStdout: "✓ Gist \"cool.txt\" deleted\n",
},
{
name: "not found",
opts: DeleteOptions{
name: "successfully delete with --yes",
opts: &DeleteOptions{
Selector: "1234",
Confirmed: true,
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "gists/1234"),
httpmock.JSONResponse(shared.Gist{ID: "1234", Files: map[string]*shared.GistFile{"cool.txt": {Filename: "cool.txt"}}}))
reg.Register(httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(200, "{}"))
},
wantStdout: "✓ Gist \"cool.txt\" deleted\n",
},
{
name: "successfully delete with prompt and --yes",
opts: &DeleteOptions{
Selector: "",
Confirmed: true,
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(200, "{}"))
},
mockPromptGists: true,
wantStdout: "✓ Gist \"cool.txt\" deleted\n",
},
{
name: "cancel delete with id",
opts: &DeleteOptions{
Selector: "1234",
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "gists/1234"),
httpmock.JSONResponse(shared.Gist{ID: "1234", Files: map[string]*shared.GistFile{"cool.txt": {Filename: "cool.txt"}}}))
},
cancel: true,
wantErr: true,
},
{
name: "cancel delete with url",
opts: &DeleteOptions{
Selector: "https://gist.github.com/myrepo/1234",
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "gists/1234"),
httpmock.JSONResponse(shared.Gist{ID: "1234", Files: map[string]*shared.GistFile{"cool.txt": {Filename: "cool.txt"}}}))
},
cancel: true,
wantErr: true,
},
{
name: "cancel delete with prompt",
opts: &DeleteOptions{
Selector: "",
},
httpStubs: func(reg *httpmock.Registry) {},
mockPromptGists: true,
cancel: true,
wantErr: true,
},
{
name: "not owned by you",
opts: &DeleteOptions{
Selector: "1234",
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "gists/1234"),
httpmock.JSONResponse(shared.Gist{ID: "1234", Files: map[string]*shared.GistFile{"cool.txt": {Filename: "cool.txt"}}}))
reg.Register(httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(404, "{}"))
},
wantErr: true,
wantStdout: "",
wantStderr: "",
wantStderr: "unable to delete gist \"cool.txt\": either the gist is not found or it is not owned by you",
},
{
name: "not found",
opts: &DeleteOptions{
Selector: "1234",
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(httpmock.REST("GET", "gists/1234"),
httpmock.StatusStringResponse(404, "{}"))
},
wantErr: true,
wantStderr: "not found",
},
{
name: "no gists",
opts: &DeleteOptions{
Selector: "",
},
httpStubs: func(reg *httpmock.Registry) {},
mockPromptGists: true,
noGists: true,
wantStdout: "No gists found.\n",
},
}
for _, tt := range tests {
reg := &httpmock.Registry{}
if tt.httpStubs != nil {
tt.httpStubs(reg)
pm := prompter.NewMockPrompter(t)
if !tt.opts.Confirmed {
pm.RegisterConfirm("Delete \"cool.txt\" gist?", func(_ string, _ bool) (bool, error) {
return !tt.cancel, nil
})
}
reg := &httpmock.Registry{}
tt.httpStubs(reg)
if tt.mockPromptGists {
if tt.noGists {
reg.Register(
httpmock.GraphQL(`query GistList\b`),
httpmock.StringResponse(
`{ "data": { "viewer": { "gists": { "nodes": [] }} } }`),
)
} else {
sixHours, _ := time.ParseDuration("6h")
sixHoursAgo := time.Now().Add(-sixHours)
reg.Register(
httpmock.GraphQL(`query GistList\b`),
httpmock.StringResponse(fmt.Sprintf(
`{ "data": { "viewer": { "gists": { "nodes": [
{
"name": "1234",
"files": [{ "name": "cool.txt" }],
"updatedAt": "%s",
"isPublic": true
}
] } } } }`,
sixHoursAgo.Format(time.RFC3339),
)),
)
pm.RegisterSelect("Select a gist", []string{"cool.txt about 6 hours ago"}, func(_, _ string, _ []string) (int, error) {
return 0, nil
})
}
}
tt.opts.Prompter = pm
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
tt.opts.Config = func() (gh.Config, error) {
return config.NewBlankConfig(), nil
}
ios, _, _, _ := iostreams.Test()
ios.SetStdoutTTY(false)
ios, _, stdout, stderr := iostreams.Test()
ios.SetStdoutTTY(true)
ios.SetStdinTTY(false)
tt.opts.IO = ios
t.Run(tt.name, func(t *testing.T) {
err := deleteRun(&tt.opts)
err := deleteRun(tt.opts)
reg.Verify(t)
if tt.wantErr {
assert.Error(t, err)
@ -119,6 +319,75 @@ func Test_deleteRun(t *testing.T) {
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantStdout, stdout.String())
assert.Equal(t, tt.wantStderr, stderr.String())
})
}
}
func Test_gistDelete(t *testing.T) {
tests := []struct {
name string
httpStubs func(*httpmock.Registry)
hostname string
gistID string
wantErr error
}{
{
name: "successful delete",
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(204, "{}"),
)
},
hostname: "github.com",
gistID: "1234",
wantErr: nil,
},
{
name: "when an gist is not found, it returns a NotFoundError",
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusStringResponse(404, "{}"),
)
},
hostname: "github.com",
gistID: "1234",
wantErr: shared.NotFoundErr,
},
{
name: "when there is a non-404 error deleting the gist, that error is returned",
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.REST("DELETE", "gists/1234"),
httpmock.StatusJSONResponse(500, `{"message": "arbitrary error"}`),
)
},
hostname: "github.com",
gistID: "1234",
wantErr: api.HTTPError{
HTTPError: &ghAPI.HTTPError{
StatusCode: 500,
Message: "arbitrary error",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
tt.httpStubs(reg)
client := api.NewClientFromHTTP(&http.Client{Transport: reg})
err := deleteGist(client, tt.hostname, tt.gistID)
if tt.wantErr != nil {
assert.ErrorAs(t, err, &tt.wantErr)
} else {
assert.NoError(t, err)
}
})
}

View file

@ -112,15 +112,17 @@ func editRun(opts *EditOptions) error {
return cmdutil.FlagErrorf("gist ID or URL required when not running interactively")
}
gistID, err = shared.PromptGists(opts.Prompter, client, host, cs)
gist, err := shared.PromptGists(opts.Prompter, client, host, cs)
if err != nil {
return err
}
if gistID == "" {
if gist.ID == "" {
fmt.Fprintln(opts.IO.Out, "No gists found.")
return nil
}
gistID = gist.ID
}
}

View file

@ -39,6 +39,19 @@ type Gist struct {
Owner *GistOwner `json:"owner,omitempty"`
}
func (g Gist) Filename() string {
filenames := make([]string, 0, len(g.Files))
for fn := range g.Files {
filenames = append(filenames, fn)
}
sort.Strings(filenames)
return filenames[0]
}
func (g Gist) TruncDescription() string {
return text.Truncate(100, text.RemoveExcessiveWhitespace(g.Description))
}
var NotFoundErr = errors.New("not found")
func GetGist(client *http.Client, hostname, gistID string) (*Gist, error) {
@ -202,47 +215,29 @@ func IsBinaryContents(contents []byte) bool {
return isBinary
}
func PromptGists(prompter prompter.Prompter, client *http.Client, host string, cs *iostreams.ColorScheme) (gistID string, err error) {
func PromptGists(prompter prompter.Prompter, client *http.Client, host string, cs *iostreams.ColorScheme) (gist *Gist, err error) {
gists, err := ListGists(client, host, 10, nil, false, "all")
if err != nil {
return "", err
return &Gist{}, err
}
if len(gists) == 0 {
return "", nil
return &Gist{}, nil
}
var opts []string
var gistIDs = make([]string, len(gists))
var opts = make([]string, len(gists))
for i, gist := range gists {
gistIDs[i] = gist.ID
description := ""
gistName := ""
if gist.Description != "" {
description = gist.Description
}
filenames := make([]string, 0, len(gist.Files))
for fn := range gist.Files {
filenames = append(filenames, fn)
}
sort.Strings(filenames)
gistName = filenames[0]
gistTime := text.FuzzyAgo(time.Now(), gist.UpdatedAt)
// TODO: support dynamic maxWidth
description = text.Truncate(100, text.RemoveExcessiveWhitespace(description))
opt := fmt.Sprintf("%s %s %s", cs.Bold(gistName), description, cs.Gray(gistTime))
opts = append(opts, opt)
opts[i] = fmt.Sprintf("%s %s %s", cs.Bold(gist.Filename()), gist.TruncDescription(), cs.Gray(gistTime))
}
result, err := prompter.Select("Select a gist", "", opts)
if err != nil {
return "", err
return &Gist{}, err
}
return gistIDs[result], nil
return &gists[result], nil
}

View file

@ -93,12 +93,15 @@ func TestIsBinaryContents(t *testing.T) {
}
func TestPromptGists(t *testing.T) {
sixHours, _ := time.ParseDuration("6h")
sixHoursAgo := time.Now().Add(-sixHours)
sixHoursAgoFormatted := sixHoursAgo.Format(time.RFC3339Nano)
tests := []struct {
name string
prompterStubs func(pm *prompter.MockPrompter)
response string
wantOut string
gist *Gist
wantOut Gist
wantErr bool
}{
{
@ -112,21 +115,21 @@ func TestPromptGists(t *testing.T) {
},
response: `{ "data": { "viewer": { "gists": { "nodes": [
{
"name": "gistid1",
"name": "1234",
"files": [{ "name": "cool.txt" }],
"description": "",
"updatedAt": "%[1]v",
"isPublic": true
},
{
"name": "gistid2",
"name": "5678",
"files": [{ "name": "gistfile0.txt" }],
"description": "",
"updatedAt": "%[1]v",
"isPublic": true
}
] } } } }`,
wantOut: "gistid1",
wantOut: Gist{ID: "1234", Files: map[string]*GistFile{"cool.txt": {Filename: "cool.txt"}}, UpdatedAt: sixHoursAgo, Public: true},
},
{
name: "multiple files, select second gist",
@ -139,26 +142,26 @@ func TestPromptGists(t *testing.T) {
},
response: `{ "data": { "viewer": { "gists": { "nodes": [
{
"name": "gistid1",
"name": "1234",
"files": [{ "name": "cool.txt" }],
"description": "",
"updatedAt": "%[1]v",
"isPublic": true
},
{
"name": "gistid2",
"name": "5678",
"files": [{ "name": "gistfile0.txt" }],
"description": "",
"updatedAt": "%[1]v",
"isPublic": true
}
] } } } }`,
wantOut: "gistid2",
wantOut: Gist{ID: "5678", Files: map[string]*GistFile{"gistfile0.txt": {Filename: "gistfile0.txt"}}, UpdatedAt: sixHoursAgo, Public: true},
},
{
name: "no files",
response: `{ "data": { "viewer": { "gists": { "nodes": [] } } } }`,
wantOut: "",
wantOut: Gist{},
},
}
@ -166,15 +169,12 @@ func TestPromptGists(t *testing.T) {
for _, tt := range tests {
reg := &httpmock.Registry{}
const query = `query GistList\b`
sixHours, _ := time.ParseDuration("6h")
sixHoursAgo := time.Now().Add(-sixHours)
reg.Register(
httpmock.GraphQL(query),
httpmock.StringResponse(fmt.Sprintf(
tt.response,
sixHoursAgo.Format(time.RFC3339),
sixHoursAgoFormatted,
)),
)
client := &http.Client{Transport: reg}
@ -185,9 +185,9 @@ func TestPromptGists(t *testing.T) {
tt.prompterStubs(mockPrompter)
}
gistID, err := PromptGists(mockPrompter, client, "github.com", ios.ColorScheme())
gist, err := PromptGists(mockPrompter, client, "github.com", ios.ColorScheme())
assert.NoError(t, err)
assert.Equal(t, tt.wantOut, gistID)
assert.Equal(t, tt.wantOut.ID, gist.ID)
reg.Verify(t)
})
}

View file

@ -93,15 +93,16 @@ func viewRun(opts *ViewOptions) error {
return cmdutil.FlagErrorf("gist ID or URL required when not running interactively")
}
gistID, err = shared.PromptGists(opts.Prompter, client, hostname, cs)
gist, err := shared.PromptGists(opts.Prompter, client, hostname, cs)
if err != nil {
return err
}
if gistID == "" {
if gist.ID == "" {
fmt.Fprintln(opts.IO.Out, "No gists found.")
return nil
}
gistID = gist.ID
}
if opts.Web {

View file

@ -263,17 +263,17 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
return cmd
}
func createRun(opts *CreateOptions) (err error) {
func createRun(opts *CreateOptions) error {
ctx, err := NewCreateContext(opts)
if err != nil {
return
return err
}
client := ctx.Client
state, err := NewIssueState(*ctx, *opts)
if err != nil {
return
return err
}
var openURL string
@ -288,15 +288,15 @@ func createRun(opts *CreateOptions) (err error) {
}
err = handlePush(*opts, *ctx)
if err != nil {
return
return err
}
openURL, err = generateCompareURL(*ctx, *state)
if err != nil {
return
return err
}
if !shared.ValidURL(openURL) {
err = fmt.Errorf("cannot open in browser: maximum URL length exceeded")
return
return err
}
return previewPR(*opts, openURL)
}
@ -344,7 +344,7 @@ func createRun(opts *CreateOptions) (err error) {
if !opts.EditorMode && (opts.FillVerbose || opts.Autofill || opts.FillFirst || (opts.TitleProvided && opts.BodyProvided)) {
err = handlePush(*opts, *ctx)
if err != nil {
return
return err
}
return submitPR(*opts, *ctx, *state)
}
@ -368,7 +368,7 @@ func createRun(opts *CreateOptions) (err error) {
var template shared.Template
template, err = tpl.Select(opts.Template)
if err != nil {
return
return err
}
if state.Title == "" {
state.Title = template.Title()
@ -378,18 +378,18 @@ func createRun(opts *CreateOptions) (err error) {
state.Title, state.Body, err = opts.TitledEditSurvey(state.Title, state.Body)
if err != nil {
return
return err
}
if state.Title == "" {
err = fmt.Errorf("title can't be blank")
return
return err
}
} else {
if !opts.TitleProvided {
err = shared.TitleSurvey(opts.Prompter, opts.IO, state)
if err != nil {
return
return err
}
}
@ -403,12 +403,12 @@ func createRun(opts *CreateOptions) (err error) {
if opts.Template != "" {
template, err = tpl.Select(opts.Template)
if err != nil {
return
return err
}
} else {
template, err = tpl.Choose()
if err != nil {
return
return err
}
}
@ -419,13 +419,13 @@ func createRun(opts *CreateOptions) (err error) {
err = shared.BodySurvey(opts.Prompter, state, templateContent)
if err != nil {
return
return err
}
}
openURL, err = generateCompareURL(*ctx, *state)
if err != nil {
return
return err
}
allowPreview := !state.HasMetadata() && shared.ValidURL(openURL) && !opts.DryRun
@ -444,12 +444,12 @@ func createRun(opts *CreateOptions) (err error) {
}
err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.BaseRepo, fetcher, state)
if err != nil {
return
return err
}
action, err = shared.ConfirmPRSubmission(opts.Prompter, !state.HasMetadata() && !opts.DryRun, false, state.Draft)
if err != nil {
return
return err
}
}
}
@ -457,12 +457,12 @@ func createRun(opts *CreateOptions) (err error) {
if action == shared.CancelAction {
fmt.Fprintln(opts.IO.ErrOut, "Discarding.")
err = cmdutil.CancelError
return
return err
}
err = handlePush(*opts, *ctx)
if err != nil {
return
return err
}
if action == shared.PreviewAction {
@ -479,7 +479,7 @@ func createRun(opts *CreateOptions) (err error) {
}
err = errors.New("expected to cancel, preview, or submit")
return
return err
}
var regexPattern = regexp.MustCompile(`(?m)^`)
@ -680,7 +680,10 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) {
var headRepo ghrepo.Interface
var headRemote *ghContext.Remote
headBranchConfig := gitClient.ReadBranchConfig(context.Background(), headBranch)
headBranchConfig, err := gitClient.ReadBranchConfig(context.Background(), headBranch)
if err != nil {
return nil, err
}
if isPushEnabled {
// determine whether the head branch is already pushed to a remote
if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found {

View file

@ -1,7 +1,6 @@
package create
import (
ctx "context"
"encoding/json"
"errors"
"fmt"
@ -1626,6 +1625,7 @@ func Test_tryDetermineTrackingRef(t *testing.T) {
tests := []struct {
name string
cmdStubs func(*run.CommandStubber)
headBranchConfig git.BranchConfig
remotes context.Remotes
expectedTrackingRef trackingRef
expectedFound bool
@ -1633,18 +1633,18 @@ func Test_tryDetermineTrackingRef(t *testing.T) {
{
name: "empty",
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "")
cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD")
},
headBranchConfig: git.BranchConfig{},
expectedTrackingRef: trackingRef{},
expectedFound: false,
},
{
name: "no match",
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "")
cs.Register("git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature", 0, "abc HEAD\nbca refs/remotes/upstream/feature")
},
headBranchConfig: git.BranchConfig{},
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
@ -1661,13 +1661,13 @@ func Test_tryDetermineTrackingRef(t *testing.T) {
{
name: "match",
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, "")
cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature$`, 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/upstream/feature
deadbeef refs/remotes/origin/feature
`))
},
headBranchConfig: git.BranchConfig{},
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
@ -1687,15 +1687,15 @@ func Test_tryDetermineTrackingRef(t *testing.T) {
{
name: "respect tracking config",
cmdStubs: func(cs *run.CommandStubber) {
cs.Register(`git config --get-regexp.+branch\\\.feature\\\.`, 0, heredoc.Doc(`
branch.feature.remote origin
branch.feature.merge refs/heads/great-feat
`))
cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/great-feat refs/remotes/origin/feature$`, 0, heredoc.Doc(`
deadbeef HEAD
deadb00f refs/remotes/origin/feature
`))
},
headBranchConfig: git.BranchConfig{
RemoteName: "origin",
MergeRef: "refs/heads/great-feat",
},
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
@ -1717,8 +1717,8 @@ func Test_tryDetermineTrackingRef(t *testing.T) {
GhPath: "some/path/gh",
GitPath: "some/path/git",
}
headBranchConfig := gitClient.ReadBranchConfig(ctx.Background(), "feature")
ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", headBranchConfig)
ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", tt.headBranchConfig)
assert.Equal(t, tt.expectedTrackingRef, ref)
assert.Equal(t, tt.expectedFound, found)

View file

@ -37,7 +37,7 @@ type finder struct {
branchFn func() (string, error)
remotesFn func() (remotes.Remotes, error)
httpClient func() (*http.Client, error)
branchConfig func(string) git.BranchConfig
branchConfig func(string) (git.BranchConfig, error)
progress progressIndicator
repo ghrepo.Interface
@ -58,7 +58,7 @@ func NewFinder(factory *cmdutil.Factory) PRFinder {
remotesFn: factory.Remotes,
httpClient: factory.HttpClient,
progress: factory.IOStreams,
branchConfig: func(s string) git.BranchConfig {
branchConfig: func(s string) (git.BranchConfig, error) {
return factory.GitClient.ReadBranchConfig(context.Background(), s)
},
}
@ -238,7 +238,10 @@ func (f *finder) parseCurrentBranch() (string, int, error) {
return "", 0, err
}
branchConfig := f.branchConfig(prHeadRef)
branchConfig, err := f.branchConfig(prHeadRef)
if err != nil {
return "", 0, err
}
// the branch is configured to merge a special PR head ref
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {

View file

@ -10,19 +10,21 @@ import (
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
branchConfig func(string) (git.BranchConfig, error)
remotesFn func() (context.Remotes, error)
selector string
fields []string
baseBranch string
}
func TestFind(t *testing.T) {
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)
branchConfig func(string) git.BranchConfig
remotesFn func() (context.Remotes, error)
selector string
fields []string
baseBranch string
}
tests := []struct {
name string
args args
@ -231,9 +233,7 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
return
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -265,9 +265,7 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
return
},
branchConfig: stubBranchConfig(git.BranchConfig{}, nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -315,11 +313,10 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
c.MergeRef = "refs/heads/blue-upstream-berries"
c.RemoteName = "origin"
return
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
RemoteName: "origin",
}, nil),
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
@ -357,11 +354,12 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
branchConfig: func(branch string) (git.BranchConfig, error) {
u, _ := url.Parse("https://github.com/UPSTREAMOWNER/REPO")
c.MergeRef = "refs/heads/blue-upstream-berries"
c.RemoteURL = u
return
return stubBranchConfig(git.BranchConfig{
MergeRef: "refs/heads/blue-upstream-berries",
RemoteURL: u,
}, nil)(branch)
},
remotesFn: nil,
},
@ -395,10 +393,9 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
c.RemoteName = "origin"
return
},
branchConfig: stubBranchConfig(git.BranchConfig{
RemoteName: "origin",
}, nil),
remotesFn: func() (context.Remotes, error) {
return context.Remotes{{
Remote: &git.Remote{Name: "origin"},
@ -436,10 +433,9 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
c.MergeRef = "refs/pull/13/head"
return
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -462,10 +458,9 @@ func TestFind(t *testing.T) {
branchFn: func() (string, error) {
return "blueberries", nil
},
branchConfig: func(branch string) (c git.BranchConfig) {
c.MergeRef = "refs/pull/13/head"
return
},
branchConfig: stubBranchConfig(git.BranchConfig{
MergeRef: "refs/pull/13/head",
}, nil),
},
httpStub: func(r *httpmock.Registry) {
r.Register(
@ -561,3 +556,49 @@ func TestFind(t *testing.T) {
})
}
}
func Test_parseCurrentBranch(t *testing.T) {
tests := []struct {
name string
args args
wantSelector string
wantPR int
wantError error
}{
{
name: "failed branch config",
args: args{
branchConfig: stubBranchConfig(git.BranchConfig{}, errors.New("branchConfigErr")),
branchFn: func() (string, error) {
return "blueberries", nil
},
},
wantSelector: "",
wantPR: 0,
wantError: errors.New("branchConfigErr"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := finder{
httpClient: func() (*http.Client, error) {
return &http.Client{}, nil
},
baseRepoFn: tt.args.baseRepoFn,
branchFn: tt.args.branchFn,
branchConfig: tt.args.branchConfig,
remotesFn: tt.args.remotesFn,
}
selector, pr, err := f.parseCurrentBranch()
assert.Equal(t, tt.wantSelector, selector)
assert.Equal(t, tt.wantPR, pr)
assert.Equal(t, tt.wantError, err)
})
}
}
func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (git.BranchConfig, error) {
return func(branch string) (git.BranchConfig, error) {
return branchConfig, err
}
}

View file

@ -72,6 +72,7 @@ func NewCmdStatus(f *cmdutil.Factory, runF func(*StatusOptions) error) *cobra.Co
}
func statusRun(opts *StatusOptions) error {
ctx := context.Background()
httpClient, err := opts.HttpClient()
if err != nil {
return err
@ -93,7 +94,11 @@ func statusRun(opts *StatusOptions) error {
}
remotes, _ := opts.Remotes()
currentPRNumber, currentPRHeadRef, err = prSelectorForCurrentBranch(opts.GitClient, baseRepo, currentBranch, remotes)
branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranch)
if err != nil {
return err
}
currentPRNumber, currentPRHeadRef, err = prSelectorForCurrentBranch(branchConfig, baseRepo, currentBranch, remotes)
if err != nil {
return fmt.Errorf("could not query for pull request for current branch: %w", err)
}
@ -184,31 +189,42 @@ func statusRun(opts *StatusOptions) error {
return nil
}
func prSelectorForCurrentBranch(gitClient *git.Client, baseRepo ghrepo.Interface, prHeadRef string, rem ghContext.Remotes) (prNumber int, selector string, err error) {
selector = prHeadRef
branchConfig := gitClient.ReadBranchConfig(context.Background(), prHeadRef)
func prSelectorForCurrentBranch(branchConfig git.BranchConfig, baseRepo ghrepo.Interface, prHeadRef string, rem ghContext.Remotes) (int, string, error) {
// the branch is configured to merge a special PR head ref
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
prNumber, _ = strconv.Atoi(m[1])
return
prNumber, err := strconv.Atoi(m[1])
if err != nil {
return 0, "", err
}
return prNumber, prHeadRef, nil
}
var branchOwner string
if branchConfig.RemoteURL != nil {
// the branch merges from a remote specified by URL
if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
branchOwner = r.RepoOwner()
r, err := ghrepo.FromURL(branchConfig.RemoteURL)
if err != nil {
// TODO: We aren't returning the error because we discovered that it was shadowed
// before refactoring to its current return pattern. Thus, we aren't confident
// that returning the error won't break existing behavior.
return 0, prHeadRef, nil
}
branchOwner = r.RepoOwner()
} else if branchConfig.RemoteName != "" {
// the branch merges from a remote specified by name
if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
branchOwner = r.RepoOwner()
r, err := rem.FindByName(branchConfig.RemoteName)
if err != nil {
// TODO: We aren't returning the error because we discovered that it was shadowed
// before refactoring to its current return pattern. Thus, we aren't confident
// that returning the error won't break existing behavior.
return 0, prHeadRef, nil
}
branchOwner = r.RepoOwner()
}
if branchOwner != "" {
selector := prHeadRef
if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
selector = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
}
@ -216,9 +232,10 @@ func prSelectorForCurrentBranch(gitClient *git.Client, baseRepo ghrepo.Interface
if !strings.EqualFold(branchOwner, baseRepo.RepoOwner()) {
selector = fmt.Sprintf("%s:%s", branchOwner, selector)
}
return 0, selector, nil
}
return
return 0, prHeadRef, nil
}
func totalApprovals(pr *api.PullRequest) int {

View file

@ -4,11 +4,11 @@ import (
"bytes"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/config"
@ -21,6 +21,7 @@ import (
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/test"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
)
func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*test.CmdOut, error) {
@ -95,6 +96,11 @@ func TestPRStatus(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatus.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -120,6 +126,11 @@ func TestPRStatus_reviewsAndChecks(t *testing.T) {
// status,conclusion matches the old StatusContextRollup query
http.Register(httpmock.GraphQL(`status,conclusion`), httpmock.FileResponse("./fixtures/prStatusChecks.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -145,6 +156,11 @@ func TestPRStatus_reviewsAndChecksWithStatesByCount(t *testing.T) {
// checkRunCount,checkRunCountsByState matches the new StatusContextRollup query
http.Register(httpmock.GraphQL(`checkRunCount,checkRunCountsByState`), httpmock.FileResponse("./fixtures/prStatusChecksWithStatesByCount.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommandWithDetector(http, "blueberries", true, "", &fd.EnabledDetectorMock{})
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -169,6 +185,11 @@ func TestPRStatus_currentBranch_showTheMostRecentPR(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatusCurrentBranch.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -197,6 +218,11 @@ func TestPRStatus_currentBranch_defaultBranch(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatusCurrentBranch.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -231,6 +257,11 @@ func TestPRStatus_currentBranch_Closed(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatusCurrentBranchClosed.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -248,6 +279,11 @@ func TestPRStatus_currentBranch_Closed_defaultBranch(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatusCurrentBranchClosedOnDefaultBranch.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -265,6 +301,11 @@ func TestPRStatus_currentBranch_Merged(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatusCurrentBranchMerged.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -282,6 +323,11 @@ func TestPRStatus_currentBranch_Merged_defaultBranch(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.FileResponse("./fixtures/prStatusCurrentBranchMergedOnDefaultBranch.json"))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -299,6 +345,11 @@ func TestPRStatus_blankSlate(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.StringResponse(`{"data": {}}`))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "blueberries", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -352,6 +403,11 @@ func TestPRStatus_detachedHead(t *testing.T) {
defer http.Verify(t)
http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.StringResponse(`{"data": {}}`))
// stub successful git command
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 0, "")
output, err := runCommand(http, "", true, "")
if err != nil {
t.Errorf("error running command `pr status`: %v", err)
@ -375,31 +431,219 @@ Requesting a code review from you
}
}
func Test_prSelectorForCurrentBranch(t *testing.T) {
func TestPRStatus_error_ReadBranchConfig(t *testing.T) {
rs, cleanup := run.Stub()
defer cleanup(t)
rs.Register(`git config --get-regexp \^branch\\.`, 1, "")
rs.Register(`git config --get-regexp \^branch\\.`, 0, heredoc.Doc(`
branch.Frederick888/main.remote git@github.com:Frederick888/playground.git
branch.Frederick888/main.merge refs/heads/main
`))
_, err := runCommand(initFakeHTTP(), "blueberries", true, "")
assert.Error(t, err)
}
repo := ghrepo.NewWithHost("octocat", "playground", "github.com")
rem := context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: repo,
func Test_prSelectorForCurrentBranch(t *testing.T) {
tests := []struct {
name string
branchConfig git.BranchConfig
baseRepo ghrepo.Interface
prHeadRef string
remotes context.Remotes
wantPrNumber int
wantSelector string
wantError error
}{
{
name: "Empty branch config",
branchConfig: git.BranchConfig{},
prHeadRef: "monalisa/main",
wantPrNumber: 0,
wantSelector: "monalisa/main",
wantError: nil,
},
{
name: "The branch is configured to merge a special PR head ref",
branchConfig: git.BranchConfig{
MergeRef: "refs/pull/42/head",
},
prHeadRef: "monalisa/main",
wantPrNumber: 42,
wantSelector: "monalisa/main",
wantError: nil,
},
{
name: "Branch merges from a remote specified by URL",
branchConfig: git.BranchConfig{
RemoteURL: &url.URL{
Scheme: "ssh",
User: url.User("git"),
Host: "github.com",
Path: "monalisa/playground.git",
},
},
baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "monalisa/main",
wantError: nil,
},
{
name: "Branch merges from a remote specified by name",
branchConfig: git.BranchConfig{
RemoteName: "upstream",
},
baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "monalisa/main",
wantError: nil,
},
{
name: "Branch is a fork and merges from a remote specified by URL",
branchConfig: git.BranchConfig{
RemoteURL: &url.URL{
Scheme: "ssh",
User: url.User("git"),
Host: "github.com",
Path: "forkName/playground.git",
},
MergeRef: "refs/heads/main",
},
baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "forkName:main",
wantError: nil,
},
{
name: "Branch is a fork and merges from a remote specified by name",
branchConfig: git.BranchConfig{
RemoteName: "origin",
},
baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "forkName:monalisa/main",
wantError: nil,
},
{
name: "Branch specifies a mergeRef and merges from a remote specified by name",
branchConfig: git.BranchConfig{
RemoteName: "upstream",
MergeRef: "refs/heads/main",
},
baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "main",
wantError: nil,
},
{
name: "Branch is a fork, specifies a mergeRef, and merges from a remote specified by name",
branchConfig: git.BranchConfig{
RemoteName: "origin",
MergeRef: "refs/heads/main",
},
baseRepo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "forkName:main",
wantError: nil,
},
{
name: "Remote URL errors",
branchConfig: git.BranchConfig{
RemoteURL: &url.URL{
Scheme: "ssh",
User: url.User("git"),
Host: "github.com",
Path: "/\\invalid?Path/",
},
},
prHeadRef: "monalisa/main",
wantPrNumber: 0,
wantSelector: "monalisa/main",
wantError: nil,
},
{
name: "Remote Name errors",
branchConfig: git.BranchConfig{
RemoteName: "nonexistentRemote",
},
prHeadRef: "monalisa/main",
remotes: context.Remotes{
&context.Remote{
Remote: &git.Remote{Name: "origin"},
Repo: ghrepo.NewWithHost("forkName", "playground", "github.com"),
},
&context.Remote{
Remote: &git.Remote{Name: "upstream"},
Repo: ghrepo.NewWithHost("monalisa", "playground", "github.com"),
},
},
wantPrNumber: 0,
wantSelector: "monalisa/main",
wantError: nil,
},
}
gitClient := &git.Client{GitPath: "some/path/git"}
prNum, headRef, err := prSelectorForCurrentBranch(gitClient, repo, "Frederick888/main", rem)
if err != nil {
t.Fatalf("prSelectorForCurrentBranch error: %v", err)
}
if prNum != 0 {
t.Errorf("expected prNum to be 0, got %q", prNum)
}
if headRef != "Frederick888:main" {
t.Errorf("expected headRef to be \"Frederick888:main\", got %q", headRef)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prNum, headRef, err := prSelectorForCurrentBranch(tt.branchConfig, tt.baseRepo, tt.prHeadRef, tt.remotes)
assert.Equal(t, tt.wantPrNumber, prNum)
assert.Equal(t, tt.wantSelector, headRef)
assert.Equal(t, tt.wantError, err)
})
}
}

View file

@ -52,20 +52,25 @@ func NewCmdExtension(io *iostreams.IOStreams, em extensions.ExtensionManager, ex
},
// PostRun handles communicating extension release information if found
PostRun: func(c *cobra.Command, args []string) {
releaseInfo := <-updateMessageChan
if releaseInfo != nil {
stderr := io.ErrOut
fmt.Fprintf(stderr, "\n\n%s %s → %s\n",
cs.Yellowf("A new release of %s is available:", ext.Name()),
cs.Cyan(strings.TrimPrefix(ext.CurrentVersion(), "v")),
cs.Cyan(strings.TrimPrefix(releaseInfo.Version, "v")))
if ext.IsPinned() {
fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s --force\n", ext.Name())
} else {
fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s\n", ext.Name())
select {
case releaseInfo := <-updateMessageChan:
if releaseInfo != nil {
stderr := io.ErrOut
fmt.Fprintf(stderr, "\n\n%s %s → %s\n",
cs.Yellowf("A new release of %s is available:", ext.Name()),
cs.Cyan(strings.TrimPrefix(ext.CurrentVersion(), "v")),
cs.Cyan(strings.TrimPrefix(releaseInfo.Version, "v")))
if ext.IsPinned() {
fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s --force\n", ext.Name())
} else {
fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s\n", ext.Name())
}
fmt.Fprintf(stderr, "%s\n\n",
cs.Yellow(releaseInfo.URL))
}
fmt.Fprintf(stderr, "%s\n\n",
cs.Yellow(releaseInfo.URL))
default:
// Do not make the user wait for extension update check if incomplete by this time.
// This is being handled in non-blocking default as there is no context to cancel like in gh update checks.
}
},
GroupID: "extension",

View file

@ -1,8 +1,10 @@
package root_test
import (
"fmt"
"io"
"testing"
"time"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/update"
@ -121,6 +123,10 @@ func TestNewCmdExtension_Updates(t *testing.T) {
em := &extensions.ExtensionManagerMock{
DispatchFunc: func(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) (bool, error) {
// Assume extension executed / dispatched without problems as test is focused on upgrade checking.
// Sleep for 100 milliseconds to allow update checking logic to complete. This would be better
// served by making the behaviour controllable by channels, but it's a larger change than desired
// just to improve the test.
time.Sleep(100 * time.Millisecond)
return true, nil
},
}
@ -169,3 +175,62 @@ func TestNewCmdExtension_Updates(t *testing.T) {
}
}
}
func TestNewCmdExtension_UpdateCheckIsNonblocking(t *testing.T) {
ios, _, _, _ := iostreams.Test()
em := &extensions.ExtensionManagerMock{
DispatchFunc: func(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) (bool, error) {
// Assume extension executed / dispatched without problems as test is focused on upgrade checking.
return true, nil
},
}
ext := &extensions.ExtensionMock{
CurrentVersionFunc: func() string {
return "1.0.0"
},
IsPinnedFunc: func() bool {
return false
},
LatestVersionFunc: func() string {
return "2.0.0"
},
NameFunc: func() string {
return "major-update"
},
UpdateAvailableFunc: func() bool {
return true
},
URLFunc: func() string {
return "https//github.com/dne/major-update"
},
}
// When the extension command is executed, the checkFunc will run in the background longer than the extension dispatch.
// If the update check is non-blocking, then the extension command will complete immediately while checkFunc is still running.
checkFunc := func(em extensions.ExtensionManager, ext extensions.Extension) (*update.ReleaseInfo, error) {
time.Sleep(30 * time.Second)
return nil, fmt.Errorf("update check should not have completed")
}
cmd := root.NewCmdExtension(ios, em, ext, checkFunc)
// The test whether update check is non-blocking is based on how long it takes for the extension command execution.
// If there is no wait time as checkFunc is sleeping sufficiently long, we can trust update check is non-blocking.
// Otherwise, if any amount of wait is encountered, it is a decent indicator that update checking is blocking.
// This is not an ideal test and indicates the update design should be revisited to be easier to understand and manage.
completed := make(chan struct{})
go func() {
_, err := cmd.ExecuteC()
require.NoError(t, err)
close(completed)
}()
select {
case <-completed:
// Expected behavior assuming extension dispatch exits immediately while checkFunc is still running.
case <-time.After(1 * time.Second):
t.Fatal("extension update check should have exited")
}
}