From a78c06970a6b351da65f449d61278b0838e774a5 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 17:28:00 -0600 Subject: [PATCH 01/51] pass predicate type to get attestation api methods Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client.go | 21 ++++++++++-------- pkg/cmd/attestation/api/client_test.go | 30 +++++++++++++------------- pkg/cmd/attestation/api/mock_client.go | 20 ++++++++--------- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 1e99a2a06..8c1af495f 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -39,8 +39,8 @@ type httpClient interface { } type Client interface { - GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) - GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) + GetByRepoAndDigest(repo, digest, predicateType string, limit int) ([]*Attestation, error) + GetByOwnerAndDigest(owner, digest, predicateType string, limit int) ([]*Attestation, error) GetTrustDomain() (string, error) } @@ -61,21 +61,21 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien } // GetByRepoAndDigest fetches the attestation by repo and digest -func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { +func (c *LiveClient) GetByRepoAndDigest(repo, digest, predicateType string, limit int) ([]*Attestation, error) { c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) - return c.getByURL(url, limit) + return c.getByURL(url, predicateType, limit) } // GetByOwnerAndDigest fetches attestation by owner and digest -func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { +func (c *LiveClient) GetByOwnerAndDigest(owner, digest, predicateType string, limit int) ([]*Attestation, error) { c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest) - return c.getByURL(url, limit) + return c.getByURL(url, predicateType, limit) } -func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) { - attestations, err := c.getAttestations(url, limit) +func (c *LiveClient) getByURL(url, predicateType string, limit int) ([]*Attestation, error) { + attestations, err := c.getAttestations(url, predicateType, limit) if err != nil { return nil, err } @@ -94,7 +94,7 @@ func (c *LiveClient) GetTrustDomain() (string, error) { return c.getTrustDomain(MetaPath) } -func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, error) { +func (c *LiveClient) getAttestations(url, predicateType string, limit int) ([]*Attestation, error) { perPage := limit if perPage <= 0 || perPage > maxLimitForFlag { return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) @@ -106,6 +106,9 @@ func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, err // ref: https://github.com/cli/go-gh/blob/d32c104a9a25c9de3d7c7b07a43ae0091441c858/example_gh_test.go#L96 url = fmt.Sprintf("%s?per_page=%d", url, perPage) + if predicateType != "" { + url = fmt.Sprintf("%s&predicate_type=%s", url, predicateType) + } var attestations []*Attestation var resp AttestationsResponse diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 787408a4e..2a62d5662 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -44,14 +44,14 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { func TestGetByDigest(t *testing.T) { c := NewClientWithMockGHClient(false) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.NoError(t, err) require.Equal(t, 5, len(attestations)) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.NoError(t, err) require.Equal(t, 5, len(attestations)) @@ -64,14 +64,14 @@ func TestGetByDigestGreaterThanLimit(t *testing.T) { limit := 3 // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", limit) require.NoError(t, err) require.Equal(t, 3, len(attestations)) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", limit) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -81,14 +81,14 @@ func TestGetByDigestGreaterThanLimit(t *testing.T) { func TestGetByDigestWithNextPage(t *testing.T) { c := NewClientWithMockGHClient(true) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.NoError(t, err) require.Equal(t, len(attestations), 10) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.NoError(t, err) require.Equal(t, len(attestations), 10) @@ -101,14 +101,14 @@ func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { limit := 7 // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", limit) require.NoError(t, err) require.Equal(t, len(attestations), limit) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", limit) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -130,12 +130,12 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) @@ -153,11 +153,11 @@ func TestGetByDigest_Error(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.Error(t, err) require.Nil(t, attestations) - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.Error(t, err) require.Nil(t, attestations) } @@ -362,7 +362,7 @@ func TestGetAttestationsRetries(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.NoError(t, err) // assert the error path was executed; because this is a paged @@ -375,7 +375,7 @@ func TestGetAttestationsRetries(t *testing.T) { require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") // same test as above, but for GetByOwnerAndDigest: - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.NoError(t, err) // because we haven't reset the mock, we have added 2 more failed requests @@ -401,7 +401,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { logger: io.NewTestHandler(), } - _, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + _, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) require.Error(t, err) fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index b2fd334c0..8e6dcdd6f 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -7,17 +7,17 @@ import ( ) type MockClient struct { - OnGetByRepoAndDigest func(repo, digest string, limit int) ([]*Attestation, error) - OnGetByOwnerAndDigest func(owner, digest string, limit int) ([]*Attestation, error) + OnGetByRepoAndDigest func(repo, digest, predicateType string, limit int) ([]*Attestation, error) + OnGetByOwnerAndDigest func(owner, digest, predicateType string, limit int) ([]*Attestation, error) OnGetTrustDomain func() (string, error) } -func (m MockClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - return m.OnGetByRepoAndDigest(repo, digest, limit) +func (m MockClient) GetByRepoAndDigest(repo, digest, predicateType string, limit int) ([]*Attestation, error) { + return m.OnGetByRepoAndDigest(repo, digest, predicateType, limit) } -func (m MockClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { - return m.OnGetByOwnerAndDigest(owner, digest, limit) +func (m MockClient) GetByOwnerAndDigest(owner, digest, predicateType string, limit int) ([]*Attestation, error) { + return m.OnGetByOwnerAndDigest(owner, digest, predicateType, limit) } func (m MockClient) GetTrustDomain() (string, error) { @@ -28,23 +28,23 @@ func makeTestAttestation() Attestation { return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} } -func OnGetByRepoAndDigestSuccess(repo, digest string, limit int) ([]*Attestation, error) { +func OnGetByRepoAndDigestSuccess(repo, digest, predicateType string, limit int) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() return []*Attestation{&att1, &att2}, nil } -func OnGetByRepoAndDigestFailure(repo, digest string, limit int) ([]*Attestation, error) { +func OnGetByRepoAndDigestFailure(repo, digest, predicateType string, limit int) ([]*Attestation, error) { return nil, fmt.Errorf("failed to fetch by repo and digest") } -func OnGetByOwnerAndDigestSuccess(owner, digest string, limit int) ([]*Attestation, error) { +func OnGetByOwnerAndDigestSuccess(owner, digest, predicateType string, limit int) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() return []*Attestation{&att1, &att2}, nil } -func OnGetByOwnerAndDigestFailure(owner, digest string, limit int) ([]*Attestation, error) { +func OnGetByOwnerAndDigestFailure(owner, digest, predicateType string, limit int) ([]*Attestation, error) { return nil, fmt.Errorf("failed to fetch by owner and digest") } From faef81f4bc7bea7fdee721a700dd563d6669f31f Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 17:28:50 -0600 Subject: [PATCH 02/51] reorganize getAttestations func to check for remote gh api fetching first Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/download/download_test.go | 4 +- .../attestation/verification/attestation.go | 13 ++++--- pkg/cmd/attestation/verify/attestation.go | 39 ++++++++++--------- pkg/cmd/attestation/verify/options.go | 6 +++ 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/pkg/cmd/attestation/download/download_test.go b/pkg/cmd/attestation/download/download_test.go index ddcd08c92..899d15339 100644 --- a/pkg/cmd/attestation/download/download_test.go +++ b/pkg/cmd/attestation/download/download_test.go @@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) { t.Run("no attestations found", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { + OnGetByOwnerAndDigest: func(repo, digest, predicateType string, limit int) ([]*api.Attestation, error) { return nil, api.ErrNoAttestationsFound }, } @@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) { t.Run("failed to fetch attestations", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { + OnGetByOwnerAndDigest: func(repo, digest, predicateType string, limit int) ([]*api.Attestation, error) { return nil, fmt.Errorf("failed to fetch attestations") }, } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index db419ebac..53a53722b 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -21,10 +21,11 @@ var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not suppo var ErrEmptyBundleFile = errors.New("provided bundle file is empty") type FetchRemoteAttestationsParams struct { - Digest string - Limit int - Owner string - Repo string + Digest string + Limit int + Owner string + PredicateType string + Repo string } // GetLocalAttestations returns a slice of attestations read from a local bundle file. @@ -96,13 +97,13 @@ func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsPara // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. if params.Repo != "" { - attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit) + attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.PredicateType, params.Limit) if err != nil { return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) } return attestations, nil } else if params.Owner != "" { - attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit) + attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.PredicateType, params.Limit) if err != nil { return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index bb96c9526..55fcf1f7e 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -10,7 +10,24 @@ import ( ) func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { - if o.BundlePath != "" { + if o.FetchAttestationsFromGitHubAPI() { + params := verification.FetchRemoteAttestationsParams{ + Digest: a.DigestWithAlg(), + Limit: o.Limit, + Owner: o.Owner, + PredicateType: o.PredicateType, + Repo: o.Repo, + } + + attestations, err := verification.GetRemoteAttestations(o.APIClient, params) + if err != nil { + msg := "✗ Loading attestations from GitHub API failed" + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) + return attestations, msg, nil + } else if o.BundlePath != "" { attestations, err := verification.GetLocalAttestations(o.BundlePath) if err != nil { msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) @@ -19,9 +36,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio pluralAttestation := text.Pluralize(len(attestations), "attestation") msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) return attestations, msg, nil - } - - if o.UseBundleFromRegistry { + } else if o.UseBundleFromRegistry { attestations, err := verification.GetOCIAttestations(o.OCIClient, a) if err != nil { msg := "✗ Loading attestations from OCI registry failed" @@ -32,21 +47,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio return attestations, msg, nil } - params := verification.FetchRemoteAttestationsParams{ - Digest: a.DigestWithAlg(), - Limit: o.Limit, - Owner: o.Owner, - Repo: o.Repo, - } - - attestations, err := verification.GetRemoteAttestations(o.APIClient, params) - if err != nil { - msg := "✗ Loading attestations from GitHub API failed" - return nil, msg, err - } - pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) - return attestations, msg, nil + return nil, "", fmt.Errorf("no valid attestation source provided") } func verifyAttestations(art artifact.DigestedArtifact, att []*api.Attestation, sgVerifier verification.SigstoreVerifier, ec verification.EnforcementCriteria) ([]*verification.AttestationProcessingResult, string, error) { diff --git a/pkg/cmd/attestation/verify/options.go b/pkg/cmd/attestation/verify/options.go index 0fbbec55a..e47c4f4a8 100644 --- a/pkg/cmd/attestation/verify/options.go +++ b/pkg/cmd/attestation/verify/options.go @@ -53,6 +53,12 @@ func (opts *Options) Clean() { } } +// FetchAttestationsFromGitHubAPI returns true if the command should fetch attestations from the GitHub API +// It checks that a bundle path is not provided and that the "use bundle from registry" flag is not set +func (opts *Options) FetchAttestationsFromGitHubAPI() bool { + return opts.BundlePath == "" && !opts.UseBundleFromRegistry +} + // AreFlagsValid checks that the provided flag combination is valid // and returns an error otherwise func (opts *Options) AreFlagsValid() error { From ad20ef35d9f97f30c352cf613ba167a08918a4ad Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 17:35:52 -0600 Subject: [PATCH 03/51] move local and oci registry attestation filtering Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/verify/attestation.go | 30 ++++++++++++++++++++--- pkg/cmd/attestation/verify/verify.go | 8 ------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 55fcf1f7e..5774170d9 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -9,6 +9,16 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/verification" ) +func filterByPredicateType(predicateType string, attestations []*api.Attestation) ([]*api.Attestation, string, error) { + // Apply predicate type filter to returned attestations + filteredAttestations := verification.FilterAttestations(predicateType, attestations) + if len(filteredAttestations) == 0 { + msg := fmt.Sprintf("✗ No attestations found with predicate type: %s\n", predicateType) + return nil, msg, fmt.Errorf("no matching predicate found") + } + return filteredAttestations, "", nil +} + func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { if o.FetchAttestationsFromGitHubAPI() { params := verification.FetchRemoteAttestationsParams{ @@ -33,18 +43,30 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) return nil, msg, err } - pluralAttestation := text.Pluralize(len(attestations), "attestation") + + filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) + if err != nil { + return nil, errMsg, err + } + + pluralAttestation := text.Pluralize(len(filtered), "attestation") msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) - return attestations, msg, nil + return filtered, msg, nil } else if o.UseBundleFromRegistry { attestations, err := verification.GetOCIAttestations(o.OCIClient, a) if err != nil { msg := "✗ Loading attestations from OCI registry failed" return nil, msg, err } - pluralAttestation := text.Pluralize(len(attestations), "attestation") + + filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) + if err != nil { + return nil, errMsg, err + } + + pluralAttestation := text.Pluralize(len(filtered), "attestation") msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) - return attestations, msg, nil + return filtered, msg, nil } return nil, "", fmt.Errorf("no valid attestation source provided") diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index 65ae8ca3e..1de4172b4 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -235,14 +235,6 @@ func runVerify(opts *Options) error { // Print the message signifying success fetching attestations opts.Logger.Println(logMsg) - // Apply predicate type filter to returned attestations - filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations) - if len(filteredAttestations) == 0 { - opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found with predicate type: %s\n"), opts.PredicateType) - return fmt.Errorf("no matching predicate found") - } - attestations = filteredAttestations - // print information about the policy that will be enforced against attestations opts.Logger.Println("\nThe following policy criteria will be enforced:") opts.Logger.Println(ec.BuildPolicyInformation()) From 95a61974bf4bff0c916fd0e5434b78c614780372 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 18:01:57 -0600 Subject: [PATCH 04/51] pass params object to api client methods Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client.go | 46 +++++++++------- pkg/cmd/attestation/api/client_test.go | 55 ++++++++++++++----- pkg/cmd/attestation/api/mock_client.go | 20 +++---- pkg/cmd/attestation/download/download.go | 2 +- pkg/cmd/attestation/download/download_test.go | 4 +- .../attestation/verification/attestation.go | 14 +---- pkg/cmd/attestation/verify/attestation.go | 2 +- 7 files changed, 84 insertions(+), 59 deletions(-) diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 8c1af495f..ae33d9ce3 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -27,6 +27,15 @@ const ( // Allow injecting backoff interval in tests. var getAttestationRetryInterval = time.Millisecond * 200 +// FetchParams are the parameters for fetching attestations from the GitHub API +type FetchParams struct { + Digest string + Limit int + Owner string + PredicateType string + Repo string +} + // githubApiClient makes REST calls to the GitHub API type githubApiClient interface { REST(hostname, method, p string, body io.Reader, data interface{}) error @@ -39,8 +48,8 @@ type httpClient interface { } type Client interface { - GetByRepoAndDigest(repo, digest, predicateType string, limit int) ([]*Attestation, error) - GetByOwnerAndDigest(owner, digest, predicateType string, limit int) ([]*Attestation, error) + GetByRepoAndDigest(params FetchParams) ([]*Attestation, error) + GetByOwnerAndDigest(params FetchParams) ([]*Attestation, error) GetTrustDomain() (string, error) } @@ -61,21 +70,20 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien } // GetByRepoAndDigest fetches the attestation by repo and digest -func (c *LiveClient) GetByRepoAndDigest(repo, digest, predicateType string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) - return c.getByURL(url, predicateType, limit) +func (c *LiveClient) GetByRepoAndDigest(params FetchParams) ([]*Attestation, error) { + url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) + return c.getByURL(url, params) } // GetByOwnerAndDigest fetches attestation by owner and digest -func (c *LiveClient) GetByOwnerAndDigest(owner, digest, predicateType string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest) - return c.getByURL(url, predicateType, limit) +func (c *LiveClient) GetByOwnerAndDigest(params FetchParams) ([]*Attestation, error) { + url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) + return c.getByURL(url, params) } -func (c *LiveClient) getByURL(url, predicateType string, limit int) ([]*Attestation, error) { - attestations, err := c.getAttestations(url, predicateType, limit) +func (c *LiveClient) getByURL(url string, params FetchParams) ([]*Attestation, error) { + c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", params.Digest) + attestations, err := c.getAttestations(url, params) if err != nil { return nil, err } @@ -94,8 +102,8 @@ func (c *LiveClient) GetTrustDomain() (string, error) { return c.getTrustDomain(MetaPath) } -func (c *LiveClient) getAttestations(url, predicateType string, limit int) ([]*Attestation, error) { - perPage := limit +func (c *LiveClient) getAttestations(url string, params FetchParams) ([]*Attestation, error) { + perPage := params.Limit if perPage <= 0 || perPage > maxLimitForFlag { return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) } @@ -106,8 +114,8 @@ func (c *LiveClient) getAttestations(url, predicateType string, limit int) ([]*A // ref: https://github.com/cli/go-gh/blob/d32c104a9a25c9de3d7c7b07a43ae0091441c858/example_gh_test.go#L96 url = fmt.Sprintf("%s?per_page=%d", url, perPage) - if predicateType != "" { - url = fmt.Sprintf("%s&predicate_type=%s", url, predicateType) + if params.PredicateType != "" { + url = fmt.Sprintf("%s&predicate_type=%s", url, params.PredicateType) } var attestations []*Attestation @@ -115,7 +123,7 @@ func (c *LiveClient) getAttestations(url, predicateType string, limit int) ([]*A bo := backoff.NewConstantBackOff(getAttestationRetryInterval) // if no attestation or less than limit, then keep fetching - for url != "" && len(attestations) < limit { + for url != "" && len(attestations) < params.Limit { err := backoff.Retry(func() error { newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) @@ -143,8 +151,8 @@ func (c *LiveClient) getAttestations(url, predicateType string, limit int) ([]*A return nil, ErrNoAttestationsFound } - if len(attestations) > limit { - return attestations[:limit], nil + if len(attestations) > params.Limit { + return attestations[:params.Limit], nil } return attestations, nil diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 2a62d5662..f77b39d08 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -42,16 +42,24 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } } +var testFetchParams = FetchParams{ + Digest: testDigest, + Limit: DefaultLimit, + PredicateType: "https://slsa.dev/provenance/v1", +} + func TestGetByDigest(t *testing.T) { c := NewClientWithMockGHClient(false) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Repo = testRepo + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, 5, len(attestations)) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, 5, len(attestations)) @@ -64,14 +72,17 @@ func TestGetByDigestGreaterThanLimit(t *testing.T) { limit := 3 // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", limit) + testFetchParams.Limit = limit + testFetchParams.Repo = testRepo + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, 3, len(attestations)) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", limit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -81,14 +92,17 @@ func TestGetByDigestGreaterThanLimit(t *testing.T) { func TestGetByDigestWithNextPage(t *testing.T) { c := NewClientWithMockGHClient(true) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Repo = testRepo + testFetchParams.Limit = 30 + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), 10) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), 10) @@ -101,14 +115,17 @@ func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { limit := 7 // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", limit) + testFetchParams.Limit = limit + testFetchParams.Repo = testRepo + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), limit) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", limit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -130,12 +147,14 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Repo = testRepo + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) @@ -153,11 +172,13 @@ func TestGetByDigest_Error(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Repo = testRepo + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.Error(t, err) require.Nil(t, attestations) - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.Error(t, err) require.Nil(t, attestations) } @@ -362,7 +383,9 @@ func TestGetAttestationsRetries(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Repo = testRepo + testFetchParams.Limit = 30 + attestations, err := c.GetByRepoAndDigest(testFetchParams) require.NoError(t, err) // assert the error path was executed; because this is a paged @@ -375,7 +398,8 @@ func TestGetAttestationsRetries(t *testing.T) { require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") // same test as above, but for GetByOwnerAndDigest: - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Owner = testOwner + attestations, err = c.GetByOwnerAndDigest(testFetchParams) require.NoError(t, err) // because we haven't reset the mock, we have added 2 more failed requests @@ -401,7 +425,8 @@ func TestGetAttestationsMaxRetries(t *testing.T) { logger: io.NewTestHandler(), } - _, err := c.GetByRepoAndDigest(testRepo, testDigest, "https://slsa.dev/provenance/v1", DefaultLimit) + testFetchParams.Repo = testRepo + _, err := c.GetByRepoAndDigest(testFetchParams) require.Error(t, err) fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index 8e6dcdd6f..be2b9b76d 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -7,17 +7,17 @@ import ( ) type MockClient struct { - OnGetByRepoAndDigest func(repo, digest, predicateType string, limit int) ([]*Attestation, error) - OnGetByOwnerAndDigest func(owner, digest, predicateType string, limit int) ([]*Attestation, error) + OnGetByRepoAndDigest func(params FetchParams) ([]*Attestation, error) + OnGetByOwnerAndDigest func(params FetchParams) ([]*Attestation, error) OnGetTrustDomain func() (string, error) } -func (m MockClient) GetByRepoAndDigest(repo, digest, predicateType string, limit int) ([]*Attestation, error) { - return m.OnGetByRepoAndDigest(repo, digest, predicateType, limit) +func (m MockClient) GetByRepoAndDigest(params FetchParams) ([]*Attestation, error) { + return m.OnGetByRepoAndDigest(params) } -func (m MockClient) GetByOwnerAndDigest(owner, digest, predicateType string, limit int) ([]*Attestation, error) { - return m.OnGetByOwnerAndDigest(owner, digest, predicateType, limit) +func (m MockClient) GetByOwnerAndDigest(params FetchParams) ([]*Attestation, error) { + return m.OnGetByOwnerAndDigest(params) } func (m MockClient) GetTrustDomain() (string, error) { @@ -28,23 +28,23 @@ func makeTestAttestation() Attestation { return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} } -func OnGetByRepoAndDigestSuccess(repo, digest, predicateType string, limit int) ([]*Attestation, error) { +func OnGetByRepoAndDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() return []*Attestation{&att1, &att2}, nil } -func OnGetByRepoAndDigestFailure(repo, digest, predicateType string, limit int) ([]*Attestation, error) { +func OnGetByRepoAndDigestFailure(params FetchParams) ([]*Attestation, error) { return nil, fmt.Errorf("failed to fetch by repo and digest") } -func OnGetByOwnerAndDigestSuccess(owner, digest, predicateType string, limit int) ([]*Attestation, error) { +func OnGetByOwnerAndDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() return []*Attestation{&att1, &att2}, nil } -func OnGetByOwnerAndDigestFailure(owner, digest, predicateType string, limit int) ([]*Attestation, error) { +func OnGetByOwnerAndDigestFailure(params FetchParams) ([]*Attestation, error) { return nil, fmt.Errorf("failed to fetch by owner and digest") } diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 6913c0787..65b6f83df 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -127,7 +127,7 @@ func runDownload(opts *Options) error { opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath) - params := verification.FetchRemoteAttestationsParams{ + params := api.FetchParams{ Digest: artifact.DigestWithAlg(), Limit: opts.Limit, Owner: opts.Owner, diff --git a/pkg/cmd/attestation/download/download_test.go b/pkg/cmd/attestation/download/download_test.go index 899d15339..629de4a66 100644 --- a/pkg/cmd/attestation/download/download_test.go +++ b/pkg/cmd/attestation/download/download_test.go @@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) { t.Run("no attestations found", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest, predicateType string, limit int) ([]*api.Attestation, error) { + OnGetByOwnerAndDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, api.ErrNoAttestationsFound }, } @@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) { t.Run("failed to fetch attestations", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest, predicateType string, limit int) ([]*api.Attestation, error) { + OnGetByOwnerAndDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, fmt.Errorf("failed to fetch attestations") }, } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index 53a53722b..c4d5330f6 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -20,14 +20,6 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1" var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl") var ErrEmptyBundleFile = errors.New("provided bundle file is empty") -type FetchRemoteAttestationsParams struct { - Digest string - Limit int - Owner string - PredicateType string - Repo string -} - // GetLocalAttestations returns a slice of attestations read from a local bundle file. func GetLocalAttestations(path string) ([]*api.Attestation, error) { var attestations []*api.Attestation @@ -90,20 +82,20 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) { return attestations, nil } -func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) { +func GetRemoteAttestations(client api.Client, params api.FetchParams) ([]*api.Attestation, error) { if client == nil { return nil, fmt.Errorf("api client must be provided") } // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. if params.Repo != "" { - attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.PredicateType, params.Limit) + attestations, err := client.GetByRepoAndDigest(params) if err != nil { return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) } return attestations, nil } else if params.Owner != "" { - attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.PredicateType, params.Limit) + attestations, err := client.GetByOwnerAndDigest(params) if err != nil { return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 5774170d9..e8211003c 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -21,7 +21,7 @@ func filterByPredicateType(predicateType string, attestations []*api.Attestation func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { if o.FetchAttestationsFromGitHubAPI() { - params := verification.FetchRemoteAttestationsParams{ + params := api.FetchParams{ Digest: a.DigestWithAlg(), Limit: o.Limit, Owner: o.Owner, From 5a895b9d72377b7793a1b0084c3ecb6261b15c6e Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 18:12:41 -0600 Subject: [PATCH 05/51] dedpulicate if else logic Signed-off-by: Meredith Lancaster --- .../attestation/verification/attestation.go | 25 ++++---- pkg/cmd/attestation/verify/attestation.go | 58 +++++++++---------- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index c4d5330f6..33d8b18b8 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -88,20 +88,23 @@ func GetRemoteAttestations(client api.Client, params api.FetchParams) ([]*api.At } // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. + var attestations []*api.Attestation + var err error + var owner string if params.Repo != "" { - attestations, err := client.GetByRepoAndDigest(params) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) - } - return attestations, nil + attestations, err = client.GetByRepoAndDigest(params) + owner = params.Repo } else if params.Owner != "" { - attestations, err := client.GetByOwnerAndDigest(params) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) - } - return attestations, nil + attestations, err = client.GetByOwnerAndDigest(params) + owner = params.Owner + } else { + return nil, fmt.Errorf("owner or repo must be provided") } - return nil, fmt.Errorf("owner or repo must be provided") + + if err != nil { + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", owner, err) + } + return attestations, err } func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) { diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index e8211003c..f956e82b8 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -37,39 +37,35 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio pluralAttestation := text.Pluralize(len(attestations), "attestation") msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) return attestations, msg, nil - } else if o.BundlePath != "" { - attestations, err := verification.GetLocalAttestations(o.BundlePath) - if err != nil { - msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) - return nil, msg, err - } - - filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) - if err != nil { - return nil, errMsg, err - } - - pluralAttestation := text.Pluralize(len(filtered), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) - return filtered, msg, nil - } else if o.UseBundleFromRegistry { - attestations, err := verification.GetOCIAttestations(o.OCIClient, a) - if err != nil { - msg := "✗ Loading attestations from OCI registry failed" - return nil, msg, err - } - - filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) - if err != nil { - return nil, errMsg, err - } - - pluralAttestation := text.Pluralize(len(filtered), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) - return filtered, msg, nil } - return nil, "", fmt.Errorf("no valid attestation source provided") + var attestations []*api.Attestation + var err error + var errMsg string + if o.BundlePath != "" { + attestations, err = verification.GetLocalAttestations(o.BundlePath) + if err != nil { + errMsg = fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) + } + } else if o.UseBundleFromRegistry { + attestations, err = verification.GetOCIAttestations(o.OCIClient, a) + if err != nil { + errMsg = "✗ Loading attestations from OCI registry failed" + } + } + + if err != nil { + return nil, errMsg, err + } + + filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) + if err != nil { + return nil, errMsg, err + } + + pluralAttestation := text.Pluralize(len(filtered), "attestation") + msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + return filtered, msg, nil } func verifyAttestations(art artifact.DigestedArtifact, att []*api.Attestation, sgVerifier verification.SigstoreVerifier, ec verification.EnforcementCriteria) ([]*verification.AttestationProcessingResult, string, error) { From a9cc7b481e2ab718163702cfe57565fe83370e19 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 18:28:27 -0600 Subject: [PATCH 06/51] create single fetch by digest client method Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client.go | 35 ++++++++++----- pkg/cmd/attestation/api/client_test.go | 32 +++++++------- pkg/cmd/attestation/api/mock_client.go | 44 +++++++------------ pkg/cmd/attestation/download/download.go | 2 +- pkg/cmd/attestation/download/download_test.go | 4 +- .../attestation/verification/attestation.go | 25 ----------- pkg/cmd/attestation/verify/attestation.go | 2 +- 7 files changed, 58 insertions(+), 86 deletions(-) diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index ae33d9ce3..a3e627852 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -48,8 +48,7 @@ type httpClient interface { } type Client interface { - GetByRepoAndDigest(params FetchParams) ([]*Attestation, error) - GetByOwnerAndDigest(params FetchParams) ([]*Attestation, error) + GetByDigest(params FetchParams) ([]*Attestation, error) GetTrustDomain() (string, error) } @@ -69,16 +68,28 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien } } -// GetByRepoAndDigest fetches the attestation by repo and digest -func (c *LiveClient) GetByRepoAndDigest(params FetchParams) ([]*Attestation, error) { - url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) - return c.getByURL(url, params) -} - -// GetByOwnerAndDigest fetches attestation by owner and digest -func (c *LiveClient) GetByOwnerAndDigest(params FetchParams) ([]*Attestation, error) { - url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) - return c.getByURL(url, params) +// GetByDigest fetches the attestation by digest and either owner or repo +// depending on which is provided +func (c *LiveClient) GetByDigest(params FetchParams) ([]*Attestation, error) { + if params.Repo == "" && params.Owner == "" { + return nil, fmt.Errorf("owner or repo must be provided") + } else if params.Repo != "" { + // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. + // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. + url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) + attestations, err := c.getByURL(url, params) + if err != nil { + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) + } + return attestations, nil + } else { + url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) + attestations, err := c.getByURL(url, params) + if err != nil { + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) + } + return attestations, nil + } } func (c *LiveClient) getByURL(url string, params FetchParams) ([]*Attestation, error) { diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index f77b39d08..fb2e36b4d 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -51,7 +51,7 @@ var testFetchParams = FetchParams{ func TestGetByDigest(t *testing.T) { c := NewClientWithMockGHClient(false) testFetchParams.Repo = testRepo - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, 5, len(attestations)) @@ -59,7 +59,7 @@ func TestGetByDigest(t *testing.T) { require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, 5, len(attestations)) @@ -74,7 +74,7 @@ func TestGetByDigestGreaterThanLimit(t *testing.T) { // The method should return five results when the limit is not set testFetchParams.Limit = limit testFetchParams.Repo = testRepo - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, 3, len(attestations)) @@ -82,7 +82,7 @@ func TestGetByDigestGreaterThanLimit(t *testing.T) { require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -94,7 +94,7 @@ func TestGetByDigestWithNextPage(t *testing.T) { c := NewClientWithMockGHClient(true) testFetchParams.Repo = testRepo testFetchParams.Limit = 30 - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), 10) @@ -102,7 +102,7 @@ func TestGetByDigestWithNextPage(t *testing.T) { require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), 10) @@ -117,7 +117,7 @@ func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { // The method should return five results when the limit is not set testFetchParams.Limit = limit testFetchParams.Repo = testRepo - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -125,7 +125,7 @@ func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.NoError(t, err) require.Equal(t, len(attestations), limit) @@ -148,13 +148,13 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { } testFetchParams.Repo = testRepo - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) @@ -173,12 +173,12 @@ func TestGetByDigest_Error(t *testing.T) { } testFetchParams.Repo = testRepo - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.Error(t, err) require.Nil(t, attestations) testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.Error(t, err) require.Nil(t, attestations) } @@ -385,7 +385,7 @@ func TestGetAttestationsRetries(t *testing.T) { testFetchParams.Repo = testRepo testFetchParams.Limit = 30 - attestations, err := c.GetByRepoAndDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParams) require.NoError(t, err) // assert the error path was executed; because this is a paged @@ -397,9 +397,9 @@ func TestGetAttestationsRetries(t *testing.T) { bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - // same test as above, but for GetByOwnerAndDigest: + // same test as above, but for GetByDigest: testFetchParams.Owner = testOwner - attestations, err = c.GetByOwnerAndDigest(testFetchParams) + attestations, err = c.GetByDigest(testFetchParams) require.NoError(t, err) // because we haven't reset the mock, we have added 2 more failed requests @@ -426,7 +426,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { } testFetchParams.Repo = testRepo - _, err := c.GetByRepoAndDigest(testFetchParams) + _, err := c.GetByDigest(testFetchParams) require.Error(t, err) fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index be2b9b76d..efcedc8b5 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -6,58 +6,44 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/test/data" ) +func makeTestAttestation() Attestation { + return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} +} + type MockClient struct { - OnGetByRepoAndDigest func(params FetchParams) ([]*Attestation, error) - OnGetByOwnerAndDigest func(params FetchParams) ([]*Attestation, error) - OnGetTrustDomain func() (string, error) + OnGetByDigest func(params FetchParams) ([]*Attestation, error) + OnGetTrustDomain func() (string, error) } -func (m MockClient) GetByRepoAndDigest(params FetchParams) ([]*Attestation, error) { - return m.OnGetByRepoAndDigest(params) -} - -func (m MockClient) GetByOwnerAndDigest(params FetchParams) ([]*Attestation, error) { - return m.OnGetByOwnerAndDigest(params) +func (m MockClient) GetByDigest(params FetchParams) ([]*Attestation, error) { + return m.OnGetByDigest(params) } func (m MockClient) GetTrustDomain() (string, error) { return m.OnGetTrustDomain() } -func makeTestAttestation() Attestation { - return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} -} - -func OnGetByRepoAndDigestSuccess(params FetchParams) ([]*Attestation, error) { +func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() return []*Attestation{&att1, &att2}, nil } -func OnGetByRepoAndDigestFailure(params FetchParams) ([]*Attestation, error) { - return nil, fmt.Errorf("failed to fetch by repo and digest") -} - -func OnGetByOwnerAndDigestSuccess(params FetchParams) ([]*Attestation, error) { - att1 := makeTestAttestation() - att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil -} - -func OnGetByOwnerAndDigestFailure(params FetchParams) ([]*Attestation, error) { +func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) { + if params.Repo != "" { + return nil, fmt.Errorf("failed to fetch by repo and digest") + } return nil, fmt.Errorf("failed to fetch by owner and digest") } func NewTestClient() *MockClient { return &MockClient{ - OnGetByRepoAndDigest: OnGetByRepoAndDigestSuccess, - OnGetByOwnerAndDigest: OnGetByOwnerAndDigestSuccess, + OnGetByDigest: OnGetByDigestSuccess, } } func NewFailTestClient() *MockClient { return &MockClient{ - OnGetByRepoAndDigest: OnGetByRepoAndDigestFailure, - OnGetByOwnerAndDigest: OnGetByOwnerAndDigestFailure, + OnGetByDigest: OnGetByDigestFailure, } } diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 65b6f83df..2a297bc9a 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -133,7 +133,7 @@ func runDownload(opts *Options) error { Owner: opts.Owner, Repo: opts.Repo, } - attestations, err := verification.GetRemoteAttestations(opts.APIClient, params) + attestations, err := opts.APIClient.GetByDigest(params) if err != nil { if errors.Is(err, api.ErrNoAttestationsFound) { fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath) diff --git a/pkg/cmd/attestation/download/download_test.go b/pkg/cmd/attestation/download/download_test.go index 629de4a66..11872daf9 100644 --- a/pkg/cmd/attestation/download/download_test.go +++ b/pkg/cmd/attestation/download/download_test.go @@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) { t.Run("no attestations found", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(params api.FetchParams) ([]*api.Attestation, error) { + OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, api.ErrNoAttestationsFound }, } @@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) { t.Run("failed to fetch attestations", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(params api.FetchParams) ([]*api.Attestation, error) { + OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, fmt.Errorf("failed to fetch attestations") }, } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index 33d8b18b8..9757b72c5 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -82,31 +82,6 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) { return attestations, nil } -func GetRemoteAttestations(client api.Client, params api.FetchParams) ([]*api.Attestation, error) { - if client == nil { - return nil, fmt.Errorf("api client must be provided") - } - // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. - // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - var attestations []*api.Attestation - var err error - var owner string - if params.Repo != "" { - attestations, err = client.GetByRepoAndDigest(params) - owner = params.Repo - } else if params.Owner != "" { - attestations, err = client.GetByOwnerAndDigest(params) - owner = params.Owner - } else { - return nil, fmt.Errorf("owner or repo must be provided") - } - - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", owner, err) - } - return attestations, err -} - func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) { attestations, err := client.GetAttestations(artifact.NameRef(), artifact.DigestWithAlg()) if err != nil { diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index f956e82b8..f91526bbd 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -29,7 +29,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio Repo: o.Repo, } - attestations, err := verification.GetRemoteAttestations(o.APIClient, params) + attestations, err := o.APIClient.GetByDigest(params) if err != nil { msg := "✗ Loading attestations from GitHub API failed" return nil, msg, err From a856a796f0e0ed801155933f4ff7daf07cef0604 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 18:34:54 -0600 Subject: [PATCH 07/51] remove duplicate predicate filtering code Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/download/download.go | 7 +++---- pkg/cmd/attestation/verification/attestation.go | 8 ++++++-- .../attestation/verification/attestation_test.go | 7 ++++--- pkg/cmd/attestation/verify/attestation.go | 14 ++------------ 4 files changed, 15 insertions(+), 21 deletions(-) diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 2a297bc9a..86cf08d72 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -144,10 +144,9 @@ func runDownload(opts *Options) error { // Apply predicate type filter to returned attestations if opts.PredicateType != "" { - filteredAttestations := verification.FilterAttestations(opts.PredicateType, attestations) - - if len(filteredAttestations) == 0 { - return fmt.Errorf("no attestations found with predicate type: %s", opts.PredicateType) + filteredAttestations, err := verification.FilterAttestations(opts.PredicateType, attestations) + if err != nil { + return fmt.Errorf("failed to filter attestations: %v", err) } attestations = filteredAttestations diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index 9757b72c5..ba357a5cc 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -97,7 +97,7 @@ type IntotoStatement struct { PredicateType string `json:"predicateType"` } -func FilterAttestations(predicateType string, attestations []*api.Attestation) []*api.Attestation { +func FilterAttestations(predicateType string, attestations []*api.Attestation) ([]*api.Attestation, error) { filteredAttestations := []*api.Attestation{} for _, each := range attestations { @@ -118,5 +118,9 @@ func FilterAttestations(predicateType string, attestations []*api.Attestation) [ } } - return filteredAttestations + if len(filteredAttestations) == 0 { + return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) + } + + return filteredAttestations, nil } diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index 8acff0c37..55a447cf4 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -157,10 +157,11 @@ func TestFilterAttestations(t *testing.T) { }, } - filtered := FilterAttestations("https://slsa.dev/provenance/v1", attestations) - + filtered, err := FilterAttestations("https://slsa.dev/provenance/v1", attestations) require.Len(t, filtered, 1) + require.NoError(t, err) - filtered = FilterAttestations("NonExistentPredicate", attestations) + filtered, err = FilterAttestations("NonExistentPredicate", attestations) require.Len(t, filtered, 0) + require.NoError(t, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index f91526bbd..c09b433b0 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -9,16 +9,6 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/verification" ) -func filterByPredicateType(predicateType string, attestations []*api.Attestation) ([]*api.Attestation, string, error) { - // Apply predicate type filter to returned attestations - filteredAttestations := verification.FilterAttestations(predicateType, attestations) - if len(filteredAttestations) == 0 { - msg := fmt.Sprintf("✗ No attestations found with predicate type: %s\n", predicateType) - return nil, msg, fmt.Errorf("no matching predicate found") - } - return filteredAttestations, "", nil -} - func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { if o.FetchAttestationsFromGitHubAPI() { params := api.FetchParams{ @@ -58,9 +48,9 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio return nil, errMsg, err } - filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) + filtered, err := verification.FilterAttestations(o.PredicateType, attestations) if err != nil { - return nil, errMsg, err + return nil, err.Error(), err } pluralAttestation := text.Pluralize(len(filtered), "attestation") From 0d0654738b7bf1cc1f200dd086c284d0a25f5be0 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 18:58:35 -0600 Subject: [PATCH 08/51] simplify client methods Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client.go | 66 +++++++++++++++++-------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index a3e627852..d0cffcb27 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -36,6 +36,19 @@ type FetchParams struct { Repo string } +func (p *FetchParams) Validate() error { + if p.Digest == "" { + return fmt.Errorf("digest must be provided") + } + if p.Limit <= 0 || p.Limit > maxLimitForFlag { + return fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) + } + if p.Repo == "" && p.Owner == "" { + return fmt.Errorf("owner or repo must be provided") + } + return nil +} + // githubApiClient makes REST calls to the GitHub API type githubApiClient interface { REST(hostname, method, p string, body io.Reader, data interface{}) error @@ -71,30 +84,8 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien // GetByDigest fetches the attestation by digest and either owner or repo // depending on which is provided func (c *LiveClient) GetByDigest(params FetchParams) ([]*Attestation, error) { - if params.Repo == "" && params.Owner == "" { - return nil, fmt.Errorf("owner or repo must be provided") - } else if params.Repo != "" { - // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. - // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) - attestations, err := c.getByURL(url, params) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) - } - return attestations, nil - } else { - url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) - attestations, err := c.getByURL(url, params) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) - } - return attestations, nil - } -} - -func (c *LiveClient) getByURL(url string, params FetchParams) ([]*Attestation, error) { c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", params.Digest) - attestations, err := c.getAttestations(url, params) + attestations, err := c.getAttestations(params) if err != nil { return nil, err } @@ -107,13 +98,24 @@ func (c *LiveClient) getByURL(url string, params FetchParams) ([]*Attestation, e return bundles, nil } -// GetTrustDomain returns the current trust domain. If the default is used -// the empty string is returned -func (c *LiveClient) GetTrustDomain() (string, error) { - return c.getTrustDomain(MetaPath) -} +func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) { + if err := params.Validate(); err != nil { + return nil, err + } + + var urlTemplate string + var resourceOwner string + if params.Repo != "" { + // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. + // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. + urlTemplate = GetAttestationByRepoAndSubjectDigestPath + resourceOwner = params.Repo + } else { + urlTemplate = GetAttestationByOwnerAndSubjectDigestPath + resourceOwner = params.Owner + } + url := fmt.Sprintf(urlTemplate, resourceOwner, params.Digest) -func (c *LiveClient) getAttestations(url string, params FetchParams) ([]*Attestation, error) { perPage := params.Limit if perPage <= 0 || perPage > maxLimitForFlag { return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) @@ -263,6 +265,12 @@ func shouldRetry(err error) bool { return false } +// GetTrustDomain returns the current trust domain. If the default is used +// the empty string is returned +func (c *LiveClient) GetTrustDomain() (string, error) { + return c.getTrustDomain(MetaPath) +} + func (c *LiveClient) getTrustDomain(url string) (string, error) { var resp MetaResponse From baeaf66011464ea49954dcf0ec25574b8fa7070c Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Mon, 24 Mar 2025 19:13:27 -0600 Subject: [PATCH 09/51] restructure api client methods Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client.go | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index d0cffcb27..61d0bee52 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -98,29 +98,21 @@ func (c *LiveClient) GetByDigest(params FetchParams) ([]*Attestation, error) { return bundles, nil } -func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) { +func (c *LiveClient) buildRequestURL(params FetchParams) (string, error) { if err := params.Validate(); err != nil { - return nil, err + return "", err } - var urlTemplate string - var resourceOwner string + var url string if params.Repo != "" { // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - urlTemplate = GetAttestationByRepoAndSubjectDigestPath - resourceOwner = params.Repo + url = fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) } else { - urlTemplate = GetAttestationByOwnerAndSubjectDigestPath - resourceOwner = params.Owner + url = fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) } - url := fmt.Sprintf(urlTemplate, resourceOwner, params.Digest) perPage := params.Limit - if perPage <= 0 || perPage > maxLimitForFlag { - return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) - } - if perPage > maxLimitForFetch { perPage = maxLimitForFetch } @@ -130,6 +122,14 @@ func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) if params.PredicateType != "" { url = fmt.Sprintf("%s&predicate_type=%s", url, params.PredicateType) } + return url, nil +} + +func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) { + url, err := c.buildRequestURL(params) + if err != nil { + return nil, err + } var attestations []*Attestation var resp AttestationsResponse @@ -139,13 +139,11 @@ func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) for url != "" && len(attestations) < params.Limit { err := backoff.Retry(func() error { newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) - if restErr != nil { if shouldRetry(restErr) { return restErr - } else { - return backoff.Permanent(restErr) } + return backoff.Permanent(restErr) } url = newURL From d1c4bf7dd9f02a50c851c74df7dd9ae8cb759609 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 25 Mar 2025 08:24:52 -0600 Subject: [PATCH 10/51] comment Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/verify/attestation.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index c09b433b0..6dd855bbc 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -10,6 +10,8 @@ import ( ) func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { + // Fetch attestations from GitHub API within this if block since predicate type + // filter is done when the API is called if o.FetchAttestationsFromGitHubAPI() { params := api.FetchParams{ Digest: a.DigestWithAlg(), @@ -29,6 +31,8 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio return attestations, msg, nil } + // Fetch attestations from local bundle or OCI registry + // Predicate type filtering is done after the attestations are fetched var attestations []*api.Attestation var err error var errMsg string From e3fbe9008f8c0caf91b88ecbb3581bc96f7484ba Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 25 Mar 2025 08:25:00 -0600 Subject: [PATCH 11/51] reduce test duplication Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client_test.go | 186 ++++++++++--------------- 1 file changed, 73 insertions(+), 113 deletions(-) diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index fb2e36b4d..0b7adfcf4 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -42,95 +42,82 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } } -var testFetchParams = FetchParams{ +var testFetchParamsWithOwner = FetchParams{ Digest: testDigest, Limit: DefaultLimit, + Owner: testOwner, + PredicateType: "https://slsa.dev/provenance/v1", +} +var testFetchParamsWithRepo = FetchParams{ + Digest: testDigest, + Limit: DefaultLimit, + Repo: testRepo, PredicateType: "https://slsa.dev/provenance/v1", } +type getByTestCase struct { + name string + params FetchParams + limit int + expectedAttestations int + expectedError error + hasNextPage bool +} + +var getByTestCases = []getByTestCase{ + { + name: "get by digest with owner", + params: testFetchParamsWithOwner, + expectedAttestations: 5, + expectedError: nil, + }, + { + name: "get by digest with repo", + params: testFetchParamsWithRepo, + expectedAttestations: 5, + expectedError: nil, + }, + { + name: "get by digest with attestations greater than limit", + params: testFetchParamsWithRepo, + limit: 3, + expectedAttestations: 3, + expectedError: nil, + }, + { + name: "get by digest with next page", + params: testFetchParamsWithRepo, + limit: 30, + expectedAttestations: 10, + expectedError: nil, + hasNextPage: true, + }, + { + name: "greater than limit with next page", + params: testFetchParamsWithRepo, + limit: 7, + expectedAttestations: 7, + expectedError: nil, + hasNextPage: true, + }, +} + func TestGetByDigest(t *testing.T) { - c := NewClientWithMockGHClient(false) - testFetchParams.Repo = testRepo - attestations, err := c.GetByDigest(testFetchParams) - require.NoError(t, err) + for _, tc := range getByTestCases { + t.Run(tc.name, func(t *testing.T) { + c := NewClientWithMockGHClient(tc.hasNextPage) - require.Equal(t, 5, len(attestations)) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + if tc.limit > 0 { + tc.params.Limit = tc.limit + } + attestations, err := c.GetByDigest(tc.params) + require.NoError(t, err) - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, 5, len(attestations)) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") -} - -func TestGetByDigestGreaterThanLimit(t *testing.T) { - c := NewClientWithMockGHClient(false) - - limit := 3 - // The method should return five results when the limit is not set - testFetchParams.Limit = limit - testFetchParams.Repo = testRepo - attestations, err := c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, 3, len(attestations)) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") -} - -func TestGetByDigestWithNextPage(t *testing.T) { - c := NewClientWithMockGHClient(true) - testFetchParams.Repo = testRepo - testFetchParams.Limit = 30 - attestations, err := c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, len(attestations), 10) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, len(attestations), 10) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") -} - -func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { - c := NewClientWithMockGHClient(true) - - limit := 7 - // The method should return five results when the limit is not set - testFetchParams.Limit = limit - testFetchParams.Repo = testRepo - attestations, err := c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + require.Equal(t, tc.expectedAttestations, len(attestations)) + bundle := (attestations)[0].Bundle + require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + }) + } } func TestGetByDigest_NoAttestationsFound(t *testing.T) { @@ -147,14 +134,7 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { logger: io.NewTestHandler(), } - testFetchParams.Repo = testRepo - attestations, err := c.GetByDigest(testFetchParams) - require.Error(t, err) - require.IsType(t, ErrNoAttestationsFound, err) - require.Nil(t, attestations) - - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) @@ -172,13 +152,7 @@ func TestGetByDigest_Error(t *testing.T) { logger: io.NewTestHandler(), } - testFetchParams.Repo = testRepo - attestations, err := c.GetByDigest(testFetchParams) - require.Error(t, err) - require.Nil(t, attestations) - - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) require.Nil(t, attestations) } @@ -383,9 +357,8 @@ func TestGetAttestationsRetries(t *testing.T) { logger: io.NewTestHandler(), } - testFetchParams.Repo = testRepo - testFetchParams.Limit = 30 - attestations, err := c.GetByDigest(testFetchParams) + testFetchParamsWithRepo.Limit = 30 + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.NoError(t, err) // assert the error path was executed; because this is a paged @@ -396,18 +369,6 @@ func TestGetAttestationsRetries(t *testing.T) { require.Equal(t, len(attestations), 10) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - // same test as above, but for GetByDigest: - testFetchParams.Owner = testOwner - attestations, err = c.GetByDigest(testFetchParams) - require.NoError(t, err) - - // because we haven't reset the mock, we have added 2 more failed requests - fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4) - - require.Equal(t, len(attestations), 10) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") } // test total retries @@ -425,8 +386,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { logger: io.NewTestHandler(), } - testFetchParams.Repo = testRepo - _, err := c.GetByDigest(testFetchParams) + _, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) From 166e211e2bab796bd72594350f35956760ce14f3 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 25 Mar 2025 08:28:33 -0600 Subject: [PATCH 12/51] clean up test fixtures Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 0b7adfcf4..384c7c9c8 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -60,7 +60,6 @@ type getByTestCase struct { params FetchParams limit int expectedAttestations int - expectedError error hasNextPage bool } @@ -69,27 +68,22 @@ var getByTestCases = []getByTestCase{ name: "get by digest with owner", params: testFetchParamsWithOwner, expectedAttestations: 5, - expectedError: nil, }, { name: "get by digest with repo", params: testFetchParamsWithRepo, expectedAttestations: 5, - expectedError: nil, }, { name: "get by digest with attestations greater than limit", params: testFetchParamsWithRepo, limit: 3, expectedAttestations: 3, - expectedError: nil, }, { name: "get by digest with next page", params: testFetchParamsWithRepo, - limit: 30, expectedAttestations: 10, - expectedError: nil, hasNextPage: true, }, { @@ -97,7 +91,6 @@ var getByTestCases = []getByTestCase{ params: testFetchParamsWithRepo, limit: 7, expectedAttestations: 7, - expectedError: nil, hasNextPage: true, }, } From 05d9156a992b7df3b5d7fd85020f8da7d2bea863 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 1 Apr 2025 11:16:00 -0600 Subject: [PATCH 13/51] add check for nil api client Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/download/download.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 86cf08d72..2cb648414 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -127,6 +127,9 @@ func runDownload(opts *Options) error { opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath) + if opts.APIClient == nil { + return fmt.Errorf("no APIClient provided") + } params := api.FetchParams{ Digest: artifact.DigestWithAlg(), Limit: opts.Limit, From 13dafefcb5ddb5eaa1e634b51338256f7ac588ee Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 1 Apr 2025 11:23:25 -0600 Subject: [PATCH 14/51] add missing nil struct checks and udpate error messages Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/mock_client.go | 4 ++-- pkg/cmd/attestation/verification/attestation_test.go | 4 ++-- pkg/cmd/attestation/verify/attestation.go | 6 ++++++ pkg/cmd/attestation/verify/verify_test.go | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index efcedc8b5..fbf44bbb9 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -31,9 +31,9 @@ func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) { func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) { if params.Repo != "" { - return nil, fmt.Errorf("failed to fetch by repo and digest") + return nil, fmt.Errorf("failed to fetch attestations from %s", params.Repo) } - return nil, fmt.Errorf("failed to fetch by owner and digest") + return nil, fmt.Errorf("failed to fetch attestations from %s", params.Owner) } func NewTestClient() *MockClient { diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index 55a447cf4..18e2c6cca 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -162,6 +162,6 @@ func TestFilterAttestations(t *testing.T) { require.NoError(t, err) filtered, err = FilterAttestations("NonExistentPredicate", attestations) - require.Len(t, filtered, 0) - require.NoError(t, err) + require.Nil(t, filtered) + require.Error(t, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 6dd855bbc..2a935a56c 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -1,6 +1,7 @@ package verify import ( + "errors" "fmt" "github.com/cli/cli/v2/internal/text" @@ -13,6 +14,11 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio // Fetch attestations from GitHub API within this if block since predicate type // filter is done when the API is called if o.FetchAttestationsFromGitHubAPI() { + if o.APIClient == nil { + errMsg := "✗ No APIClient provided" + return nil, errMsg, errors.New(errMsg) + } + params := api.FetchParams{ Digest: a.DigestWithAlg(), Limit: o.Limit, diff --git a/pkg/cmd/attestation/verify/verify_test.go b/pkg/cmd/attestation/verify/verify_test.go index 092a009d8..2b821a435 100644 --- a/pkg/cmd/attestation/verify/verify_test.go +++ b/pkg/cmd/attestation/verify/verify_test.go @@ -510,7 +510,7 @@ func TestRunVerify(t *testing.T) { err := runVerify(&customOpts) require.Error(t, err) - require.ErrorContains(t, err, "no matching predicate found") + require.ErrorContains(t, err, "no attestations found with predicate type") }) t.Run("with valid OCI artifact with UseBundleFromRegistry flag but no bundle return from registry", func(t *testing.T) { From f43ec0079bab00cbc2065da7468c6867511818c9 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 1 Apr 2025 11:52:13 -0600 Subject: [PATCH 15/51] add test for predicate type filtering Signed-off-by: Meredith Lancaster --- .../verify/verify-with-internal-github-sigstore.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh b/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh index 647a13a4c..cea3c7228 100644 --- a/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh +++ b/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh @@ -14,3 +14,9 @@ if ! $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owne echo "Failed to verify" exit 1 fi + +# Try to verify when specifying a predicate type that does not match the attestation +if $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owner=cli --predicate-type=my-custom-predicate-type; then + echo "Verification should have failed" + exit 1 +fi From 56d924d25b39bb30497ba26c268d71428d794880 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 1 Apr 2025 12:58:37 -0600 Subject: [PATCH 16/51] getAttestations unit tests Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/verify/attestation.go | 18 +-- .../verify/attestation_integration_test.go | 119 ------------------ .../attestation/verify/attestation_test.go | 53 ++++++++ 3 files changed, 63 insertions(+), 127 deletions(-) delete mode 100644 pkg/cmd/attestation/verify/attestation_integration_test.go create mode 100644 pkg/cmd/attestation/verify/attestation_test.go diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 2a935a56c..68d38f985 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -41,30 +41,32 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio // Predicate type filtering is done after the attestations are fetched var attestations []*api.Attestation var err error - var errMsg string + var msg string if o.BundlePath != "" { attestations, err = verification.GetLocalAttestations(o.BundlePath) if err != nil { - errMsg = fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg = fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + } else { + msg = fmt.Sprintf("Loaded %d attestations from %s", len(attestations), o.BundlePath) } } else if o.UseBundleFromRegistry { attestations, err = verification.GetOCIAttestations(o.OCIClient, a) if err != nil { - errMsg = "✗ Loading attestations from OCI registry failed" + msg = "✗ Loading attestations from OCI registry failed" + } else { + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg = fmt.Sprintf("Loaded %s from OCI registry", pluralAttestation) } } - if err != nil { - return nil, errMsg, err + return nil, msg, err } filtered, err := verification.FilterAttestations(o.PredicateType, attestations) if err != nil { return nil, err.Error(), err } - - pluralAttestation := text.Pluralize(len(filtered), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) return filtered, msg, nil } diff --git a/pkg/cmd/attestation/verify/attestation_integration_test.go b/pkg/cmd/attestation/verify/attestation_integration_test.go deleted file mode 100644 index 9ff174141..000000000 --- a/pkg/cmd/attestation/verify/attestation_integration_test.go +++ /dev/null @@ -1,119 +0,0 @@ -//go:build integration - -package verify - -import ( - "testing" - - "github.com/cli/cli/v2/pkg/cmd/attestation/api" - "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" - "github.com/cli/cli/v2/pkg/cmd/attestation/io" - "github.com/cli/cli/v2/pkg/cmd/attestation/test" - "github.com/cli/cli/v2/pkg/cmd/attestation/verification" - o "github.com/cli/cli/v2/pkg/option" - "github.com/sigstore/sigstore-go/pkg/fulcio/certificate" - "github.com/stretchr/testify/require" -) - -func getAttestationsFor(t *testing.T, bundlePath string) []*api.Attestation { - t.Helper() - - attestations, err := verification.GetLocalAttestations(bundlePath) - require.NoError(t, err) - - return attestations -} - -func TestVerifyAttestations(t *testing.T) { - sgVerifier := verification.NewLiveSigstoreVerifier(verification.SigstoreConfig{ - Logger: io.NewTestHandler(), - TUFMetadataDir: o.Some(t.TempDir()), - }) - - certSummary := certificate.Summary{} - certSummary.SourceRepositoryOwnerURI = "https://github.com/sigstore" - certSummary.SourceRepositoryURI = "https://github.com/sigstore/sigstore-js" - certSummary.Issuer = verification.GitHubOIDCIssuer - - ec := verification.EnforcementCriteria{ - Certificate: certSummary, - PredicateType: verification.SLSAPredicateV1, - SANRegex: "^https://github.com/sigstore/", - } - require.NoError(t, ec.Valid()) - - artifactPath := test.NormalizeRelativePath("../test/data/sigstore-js-2.1.0.tgz") - a, err := artifact.NewDigestedArtifact(nil, artifactPath, "sha512") - require.NoError(t, err) - - t.Run("all attestations pass verification", func(t *testing.T) { - attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") - require.Len(t, attestations, 2) - results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, ec) - require.NoError(t, err) - require.Zero(t, errMsg) - require.Len(t, results, 2) - }) - - t.Run("passes verification with 2/3 attestations passing Sigstore verification", func(t *testing.T) { - invalidBundle := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") - attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") - attestations = append(attestations, invalidBundle[0]) - require.Len(t, attestations, 3) - - results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, ec) - require.NoError(t, err) - require.Zero(t, errMsg) - require.Len(t, results, 2) - }) - - t.Run("fails verification when Sigstore verification fails", func(t *testing.T) { - invalidBundle := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") - invalidBundle2 := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") - attestations := append(invalidBundle, invalidBundle2...) - require.Len(t, attestations, 2) - - results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, ec) - require.Error(t, err) - require.Contains(t, errMsg, "✗ Sigstore verification failed") - require.Nil(t, results) - }) - - t.Run("attestations fail to verify when cert extensions don't match enforcement criteria", func(t *testing.T) { - sgjAttestation := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") - reusableWorkflowAttestations := getAttestationsFor(t, "../test/data/reusable-workflow-attestation.sigstore.json") - attestations := []*api.Attestation{sgjAttestation[0], reusableWorkflowAttestations[0], sgjAttestation[1]} - require.Len(t, attestations, 3) - - rwfResult := verification.BuildMockResult(reusableWorkflowAttestations[0].Bundle, "", "", "https://github.com/malancas", "", verification.GitHubOIDCIssuer) - sgjResult := verification.BuildSigstoreJsMockResult(t) - mockResults := []*verification.AttestationProcessingResult{&sgjResult, &rwfResult, &sgjResult} - mockSgVerifier := verification.NewMockSigstoreVerifierWithMockResults(t, mockResults) - - // we want to test that attestations that pass Sigstore verification but fail - // cert extension verification are filtered out properly in the second step - // in verifyAttestations. By using a mock Sigstore verifier, we can ensure - // that the call to verification.VerifyCertExtensions in verifyAttestations - // is filtering out attestations as expected - results, errMsg, err := verifyAttestations(*a, attestations, mockSgVerifier, ec) - require.NoError(t, err) - require.Zero(t, errMsg) - require.Len(t, results, 2) - for _, result := range results { - require.NotEqual(t, result.Attestation.Bundle, reusableWorkflowAttestations[0].Bundle) - } - }) - - t.Run("fails verification when cert extension verification fails", func(t *testing.T) { - attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") - require.Len(t, attestations, 2) - - expectedCriteria := ec - expectedCriteria.Certificate.SourceRepositoryOwnerURI = "https://github.com/wrong" - - results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, expectedCriteria) - require.Error(t, err) - require.Contains(t, errMsg, "✗ Policy verification failed") - require.Nil(t, results) - }) -} diff --git a/pkg/cmd/attestation/verify/attestation_test.go b/pkg/cmd/attestation/verify/attestation_test.go new file mode 100644 index 000000000..93ac7d327 --- /dev/null +++ b/pkg/cmd/attestation/verify/attestation_test.go @@ -0,0 +1,53 @@ +package verify + +import ( + "testing" + + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" + "github.com/cli/cli/v2/pkg/cmd/attestation/verification" + "github.com/stretchr/testify/require" +) + +func TestGetAttestations_OCIRegistry_PredicateTypeFiltering(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + OCIClient: oci.MockClient{}, + PredicateType: verification.SLSAPredicateV1, + Repo: "cli/cli", + UseBundleFromRegistry: true, + } + attestations, msg, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Contains(t, msg, "Loaded 2 attestations from OCI registry") + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, msg, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Contains(t, msg, "no attestations found with predicate type") + require.Nil(t, attestations) +} + +func TestGetAttestations_LocalBundle_PredicateTypeFiltering(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + BundlePath: "../test/data/sigstore-js-2.1.0-bundle.json", + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, msg, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Contains(t, msg, "Loaded 1 attestation from ../test/data/gh_2.60.1_windows_arm64.zip") + require.Len(t, attestations, 1) + + o.PredicateType = "custom predicate type" + attestations, msg, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Contains(t, msg, "no attestations found with predicate type") + require.Nil(t, attestations) +} From 164a56cb663702233a4bf4e884acf8d0578d6632 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Thu, 3 Apr 2025 11:02:45 -0600 Subject: [PATCH 17/51] move filterAttestations function Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/attestation.go | 35 +++++++++++++++++++ pkg/cmd/attestation/api/mock_client.go | 7 +++- pkg/cmd/attestation/download/download.go | 3 +- .../attestation/verification/attestation.go | 32 ----------------- .../verification/attestation_test.go | 4 +-- pkg/cmd/attestation/verify/attestation.go | 2 +- .../attestation/verify/attestation_test.go | 26 +++++++++++--- 7 files changed, 67 insertions(+), 42 deletions(-) diff --git a/pkg/cmd/attestation/api/attestation.go b/pkg/cmd/attestation/api/attestation.go index fd6b484a7..daec12b50 100644 --- a/pkg/cmd/attestation/api/attestation.go +++ b/pkg/cmd/attestation/api/attestation.go @@ -1,7 +1,10 @@ package api import ( + "encoding/json" "errors" + "fmt" + "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -20,3 +23,35 @@ type Attestation struct { type AttestationsResponse struct { Attestations []*Attestation `json:"attestations"` } + +type IntotoStatement struct { + PredicateType string `json:"predicateType"` +} + +func FilterAttestations(predicateType string, attestations []*Attestation) ([]*Attestation, error) { + filteredAttestations := []*Attestation{} + + for _, each := range attestations { + dsseEnvelope := each.Bundle.GetDsseEnvelope() + if dsseEnvelope != nil { + if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { + // Don't fail just because an entry isn't intoto + continue + } + var intotoStatement IntotoStatement + if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { + // Don't fail just because a single entry can't be unmarshalled + continue + } + if intotoStatement.PredicateType == predicateType { + filteredAttestations = append(filteredAttestations, each) + } + } + } + + if len(filteredAttestations) == 0 { + return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) + } + + return filteredAttestations, nil +} diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index fbf44bbb9..b6062b39f 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -26,7 +26,12 @@ func (m MockClient) GetTrustDomain() (string, error) { func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil + attestations := []*Attestation{&att1, &att2} + if params.PredicateType != "" { + return FilterAttestations(params.PredicateType, attestations) + } + + return attestations, nil } func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) { diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 2cb648414..8d1d1dc05 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -9,7 +9,6 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/auth" "github.com/cli/cli/v2/pkg/cmd/attestation/io" - "github.com/cli/cli/v2/pkg/cmd/attestation/verification" "github.com/cli/cli/v2/pkg/cmdutil" ghauth "github.com/cli/go-gh/v2/pkg/auth" @@ -147,7 +146,7 @@ func runDownload(opts *Options) error { // Apply predicate type filter to returned attestations if opts.PredicateType != "" { - filteredAttestations, err := verification.FilterAttestations(opts.PredicateType, attestations) + filteredAttestations, err := api.FilterAttestations(opts.PredicateType, attestations) if err != nil { return fmt.Errorf("failed to filter attestations: %v", err) } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index ba357a5cc..10eb02ac4 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -92,35 +92,3 @@ func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ( } return attestations, nil } - -type IntotoStatement struct { - PredicateType string `json:"predicateType"` -} - -func FilterAttestations(predicateType string, attestations []*api.Attestation) ([]*api.Attestation, error) { - filteredAttestations := []*api.Attestation{} - - for _, each := range attestations { - dsseEnvelope := each.Bundle.GetDsseEnvelope() - if dsseEnvelope != nil { - if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { - // Don't fail just because an entry isn't intoto - continue - } - var intotoStatement IntotoStatement - if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { - // Don't fail just because a single entry can't be unmarshalled - continue - } - if intotoStatement.PredicateType == predicateType { - filteredAttestations = append(filteredAttestations, each) - } - } - } - - if len(filteredAttestations) == 0 { - return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) - } - - return filteredAttestations, nil -} diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index 18e2c6cca..6826e2e40 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -157,11 +157,11 @@ func TestFilterAttestations(t *testing.T) { }, } - filtered, err := FilterAttestations("https://slsa.dev/provenance/v1", attestations) + filtered, err := api.FilterAttestations("https://slsa.dev/provenance/v1", attestations) require.Len(t, filtered, 1) require.NoError(t, err) - filtered, err = FilterAttestations("NonExistentPredicate", attestations) + filtered, err = api.FilterAttestations("NonExistentPredicate", attestations) require.Nil(t, filtered) require.Error(t, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 68d38f985..1b98fabf3 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -63,7 +63,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio return nil, msg, err } - filtered, err := verification.FilterAttestations(o.PredicateType, attestations) + filtered, err := api.FilterAttestations(o.PredicateType, attestations) if err != nil { return nil, err.Error(), err } diff --git a/pkg/cmd/attestation/verify/attestation_test.go b/pkg/cmd/attestation/verify/attestation_test.go index 93ac7d327..f015805ae 100644 --- a/pkg/cmd/attestation/verify/attestation_test.go +++ b/pkg/cmd/attestation/verify/attestation_test.go @@ -3,6 +3,7 @@ package verify import ( "testing" + "github.com/cli/cli/v2/pkg/cmd/attestation/api" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/verification" @@ -40,14 +41,31 @@ func TestGetAttestations_LocalBundle_PredicateTypeFiltering(t *testing.T) { PredicateType: verification.SLSAPredicateV1, Repo: "sigstore/sigstore-js", } - attestations, msg, err := getAttestations(o, *artifact) + attestations, _, err := getAttestations(o, *artifact) require.NoError(t, err) - require.Contains(t, msg, "Loaded 1 attestation from ../test/data/gh_2.60.1_windows_arm64.zip") require.Len(t, attestations, 1) o.PredicateType = "custom predicate type" - attestations, msg, err = getAttestations(o, *artifact) + attestations, _, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Nil(t, attestations) +} + +func TestGetAttestations_GhAPI_NoAttestationsFound(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + APIClient: api.NewTestClient(), + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, _, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, _, err = getAttestations(o, *artifact) require.Error(t, err) - require.Contains(t, msg, "no attestations found with predicate type") require.Nil(t, attestations) } From 69507282d251ffde053baa28eb26e644ed4a3be5 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Thu, 3 Apr 2025 11:07:06 -0600 Subject: [PATCH 18/51] restore deleted file Signed-off-by: Meredith Lancaster --- .../verify/attestation_integration_test.go | 119 ++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 pkg/cmd/attestation/verify/attestation_integration_test.go diff --git a/pkg/cmd/attestation/verify/attestation_integration_test.go b/pkg/cmd/attestation/verify/attestation_integration_test.go new file mode 100644 index 000000000..9ff174141 --- /dev/null +++ b/pkg/cmd/attestation/verify/attestation_integration_test.go @@ -0,0 +1,119 @@ +//go:build integration + +package verify + +import ( + "testing" + + "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" + "github.com/cli/cli/v2/pkg/cmd/attestation/io" + "github.com/cli/cli/v2/pkg/cmd/attestation/test" + "github.com/cli/cli/v2/pkg/cmd/attestation/verification" + o "github.com/cli/cli/v2/pkg/option" + "github.com/sigstore/sigstore-go/pkg/fulcio/certificate" + "github.com/stretchr/testify/require" +) + +func getAttestationsFor(t *testing.T, bundlePath string) []*api.Attestation { + t.Helper() + + attestations, err := verification.GetLocalAttestations(bundlePath) + require.NoError(t, err) + + return attestations +} + +func TestVerifyAttestations(t *testing.T) { + sgVerifier := verification.NewLiveSigstoreVerifier(verification.SigstoreConfig{ + Logger: io.NewTestHandler(), + TUFMetadataDir: o.Some(t.TempDir()), + }) + + certSummary := certificate.Summary{} + certSummary.SourceRepositoryOwnerURI = "https://github.com/sigstore" + certSummary.SourceRepositoryURI = "https://github.com/sigstore/sigstore-js" + certSummary.Issuer = verification.GitHubOIDCIssuer + + ec := verification.EnforcementCriteria{ + Certificate: certSummary, + PredicateType: verification.SLSAPredicateV1, + SANRegex: "^https://github.com/sigstore/", + } + require.NoError(t, ec.Valid()) + + artifactPath := test.NormalizeRelativePath("../test/data/sigstore-js-2.1.0.tgz") + a, err := artifact.NewDigestedArtifact(nil, artifactPath, "sha512") + require.NoError(t, err) + + t.Run("all attestations pass verification", func(t *testing.T) { + attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") + require.Len(t, attestations, 2) + results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, ec) + require.NoError(t, err) + require.Zero(t, errMsg) + require.Len(t, results, 2) + }) + + t.Run("passes verification with 2/3 attestations passing Sigstore verification", func(t *testing.T) { + invalidBundle := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") + attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") + attestations = append(attestations, invalidBundle[0]) + require.Len(t, attestations, 3) + + results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, ec) + require.NoError(t, err) + require.Zero(t, errMsg) + require.Len(t, results, 2) + }) + + t.Run("fails verification when Sigstore verification fails", func(t *testing.T) { + invalidBundle := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") + invalidBundle2 := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") + attestations := append(invalidBundle, invalidBundle2...) + require.Len(t, attestations, 2) + + results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, ec) + require.Error(t, err) + require.Contains(t, errMsg, "✗ Sigstore verification failed") + require.Nil(t, results) + }) + + t.Run("attestations fail to verify when cert extensions don't match enforcement criteria", func(t *testing.T) { + sgjAttestation := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") + reusableWorkflowAttestations := getAttestationsFor(t, "../test/data/reusable-workflow-attestation.sigstore.json") + attestations := []*api.Attestation{sgjAttestation[0], reusableWorkflowAttestations[0], sgjAttestation[1]} + require.Len(t, attestations, 3) + + rwfResult := verification.BuildMockResult(reusableWorkflowAttestations[0].Bundle, "", "", "https://github.com/malancas", "", verification.GitHubOIDCIssuer) + sgjResult := verification.BuildSigstoreJsMockResult(t) + mockResults := []*verification.AttestationProcessingResult{&sgjResult, &rwfResult, &sgjResult} + mockSgVerifier := verification.NewMockSigstoreVerifierWithMockResults(t, mockResults) + + // we want to test that attestations that pass Sigstore verification but fail + // cert extension verification are filtered out properly in the second step + // in verifyAttestations. By using a mock Sigstore verifier, we can ensure + // that the call to verification.VerifyCertExtensions in verifyAttestations + // is filtering out attestations as expected + results, errMsg, err := verifyAttestations(*a, attestations, mockSgVerifier, ec) + require.NoError(t, err) + require.Zero(t, errMsg) + require.Len(t, results, 2) + for _, result := range results { + require.NotEqual(t, result.Attestation.Bundle, reusableWorkflowAttestations[0].Bundle) + } + }) + + t.Run("fails verification when cert extension verification fails", func(t *testing.T) { + attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") + require.Len(t, attestations, 2) + + expectedCriteria := ec + expectedCriteria.Certificate.SourceRepositoryOwnerURI = "https://github.com/wrong" + + results, errMsg, err := verifyAttestations(*a, attestations, sgVerifier, expectedCriteria) + require.Error(t, err) + require.Contains(t, errMsg, "✗ Policy verification failed") + require.Nil(t, results) + }) +} From 06d22d96c01afae5aaaa4ff8e7895f39fa52d293 Mon Sep 17 00:00:00 2001 From: Barak Amar Date: Fri, 4 Apr 2025 11:14:02 +0300 Subject: [PATCH 19/51] handle find pr number 0 --- pkg/cmd/pr/shared/finder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index a54528527..6e0ea0401 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -245,7 +245,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err } var pr *api.PullRequest - if f.prNumber > 0 { + if f.prNumber > 0 || f.branchName == "" { if numberFieldOnly { // avoid hitting the API if we already have all the information return &api.PullRequest{Number: f.prNumber}, f.baseRefRepo, nil From 747f015f48e1c063fae68b6041bf39cd83bea2b8 Mon Sep 17 00:00:00 2001 From: Barak Amar Date: Mon, 7 Apr 2025 22:38:28 +0300 Subject: [PATCH 20/51] test pr number 0 --- pkg/cmd/pr/shared/finder_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 36551ab42..25a948416 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -89,6 +89,19 @@ func TestFind(t *testing.T) { wantPR: 13, wantRepo: "https://github.com/ORIGINOWNER/REPO", }, + { + name: "PR number 0 is invalid", + args: args{ + selector: "0", + fields: []string{"id", "number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil + }, + branchConfig: stubBranchConfig(git.BranchConfig{}, nil), + }, + wantErr: true, + }, { name: "number argument with base branch", args: args{ From a1f5d42283071d42f6725518f1cd4f51955c8e0d Mon Sep 17 00:00:00 2001 From: Barak Amar Date: Thu, 17 Apr 2025 17:13:28 +0300 Subject: [PATCH 21/51] Update the test code to align with latest changes --- pkg/cmd/pr/shared/finder_test.go | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 09c2bf7a7..66fb900eb 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -83,19 +83,6 @@ func TestFind(t *testing.T) { wantPR: 13, wantRepo: "https://github.com/ORIGINOWNER/REPO", }, - { - name: "PR number 0 is invalid", - args: args{ - selector: "0", - fields: []string{"id", "number"}, - baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), - branchFn: func() (string, error) { - return "blueberries", nil - }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - }, - wantErr: true, - }, { name: "number argument with base branch", args: args{ @@ -178,6 +165,25 @@ func TestFind(t *testing.T) { wantPR: 13, wantRepo: "https://github.com/ORIGINOWNER/REPO", }, + { + name: "pr number zero", + args: args{ + selector: "0", + fields: []string{"number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil + }, + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, + }, + httpStub: nil, + wantPR: 0, + wantRepo: "https://github.com/ORIGINOWNER/REPO", + }, { name: "number with hash argument", args: args{ From 274a09bbc97657163b2eabf64a5c979987fd64c3 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Wed, 23 Apr 2025 10:11:51 -0400 Subject: [PATCH 22/51] Initial `gh accessibility` command draft This commit captures the initial command along with functionality and description. There is an internal discussion about the appropriate place for some of this content. --- pkg/cmd/accessibility/accessibility.go | 150 +++++++++++++++++++++++++ pkg/cmd/root/help.go | 11 +- pkg/cmd/root/root.go | 2 + 3 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 pkg/cmd/accessibility/accessibility.go diff --git a/pkg/cmd/accessibility/accessibility.go b/pkg/cmd/accessibility/accessibility.go new file mode 100644 index 000000000..4992d488b --- /dev/null +++ b/pkg/cmd/accessibility/accessibility.go @@ -0,0 +1,150 @@ +package accessibility + +import ( + "fmt" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/browser" + "github.com/cli/cli/v2/internal/text" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/spf13/cobra" +) + +const ( + communityURL = "https://github.com/orgs/community/discussions/categories/accessibility" +) + +type AccessibilityOptions struct { + IO *iostreams.IOStreams + Browser browser.Browser + Web bool +} + +func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { + opts := AccessibilityOptions{ + IO: f.IOStreams, + Browser: f.Browser, + } + + cmd := &cobra.Command{ + Use: "accessibility", + Aliases: []string{"a11y"}, + Short: "Learn about GitHub CLI accessibility experience", + Long: longDescription(opts.IO), + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + if opts.Web { + if opts.IO.IsStdoutTTY() { + fmt.Fprintf(opts.IO.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(communityURL)) + } + return opts.Browser.Browse(communityURL) + } + + return cmd.Help() + }, + Example: heredoc.Doc(` + # Open the GitHub Community Accessibility discussions in your browser + $ gh accessibility --web + + # Display color using customizable, 4-bit accessible colors + $ gh config set accessible_colors enabled + + # Display issue and pull request labels using RGB hex color codes in terminals that support 24-bit truecolor + $ gh config set color_labels enabled + + # Use input prompts without redrawing the screen + $ gh config set accessible_prompter enabled + + # Disable motion-based spinners for progress indicators in favor of text + $ gh config set spinner disabled + `), + } + + cmd.Flags().BoolVarP(&opts.Web, "web", "w", false, "Open the GitHub Community Accessibility discussions in the browser") + cmdutil.DisableAuthCheck(cmd) + + return cmd +} + +func longDescription(io *iostreams.IOStreams) string { + cs := io.ColorScheme() + title := cs.Bold("LEARN ABOUT GITHUB CLI ACCESSIBILITY EFFORTS") + color := cs.Bold("CUSTOMIZABLE AND CONTRASTING COLORS") + prompter := cs.Bold("NON-INTERACTIVE USER INPUT PROMPTING") + spinner := cs.Bold("TEXT-BASED SPINNERS") + + return heredoc.Docf(` + %[2]s + + As the home for all developers, we want every developer to feel welcome in our + community and be empowered to contribute to the future of global software + development with everything GitHub has to offer including the GitHub CLI. + + We invite you to join us in improving GitHub CLI accessibility by sharing your + feedback and ideas in the GitHub Community Accessibility discussions: + %[3]s + + + %[4]s + + Color is a common approach to enhance user experiences, however users can find + themselves with a worse experience due to insufficient contrast or + customizability. + + To create an accessible experience, CLIs should use color palettes based on + terminal background appearance and limit colors to 4-bit ANSI color palettes, + which users can customize within terminal preferences. + + With this new experience, the GitHub CLI provides multiple options to address + color usage: + + 1. The GitHub CLI will use 4-bit color palette for increased color contrast based on + dark and light backgrounds including rendering markdown based on GitHub Primer. + + To enable this experience, use one of the following methods: + - Run %[1]sgh config set accessible_colors enabled%[1]s + - Set %[1]sGH_ACCESSIBLE_COLORS=enabled%[1]s environment variable + + 2. The GitHub CLI will display issue and pull request labels' custom RGB colors + in terminals with truecolor support. + + To enable this experience, use one of the following methods: + - Run %[1]sgh config set color_labels enabled%[1]s + - Set %[1]sGH_COLOR_LABELS=enabled%[1]s environment variable + + + %[5]s + + Interactive text user interfaces are an advanced approach to enhance user + experiences, which manipulate the terminal cursor to redraw parts of the screen. + However, this can be difficult for speech synthesizers or braille displays to + accurately detect and read. + + To create an accessible experience, CLIs should give users the ability to disable + this interactivity while providing a similar experience. + + With this new experience, the GitHub CLI will use non-interactive prompts for + user input. + + To enable this experience, use one of the following methods: + - Run %[1]sgh config set accessible_prompter enabled%[1]s + - Set %[1]sGH_ACCESSIBLE_PROMPTER=enabled%[1]s environment variable + + + %[6]s + + Motion-based spinners are a common approach to communicate activity, which + manipulate the terminal cursor to create a spinning effect. However, this can be + difficult for users with motion sensitivity as well as speech synthesizers. + + To create an accessible experience, CLIs should give users the ability to disable + this interactivity while providing a similar experience. + + With this new experience, the GitHub CLI will use text-based progress indicators. + + To enable this experience, use one of the following methods: + - Run %[1]sgh config set spinner disabled%[1]s + - Set %[1]sGH_SPINNER_DISABLED=yes%[1]s environment variable + `, "`", title, communityURL, color, prompter, spinner) +} diff --git a/pkg/cmd/root/help.go b/pkg/cmd/root/help.go index 7f8fb1c2e..ec6499f21 100644 --- a/pkg/cmd/root/help.go +++ b/pkg/cmd/root/help.go @@ -109,8 +109,6 @@ func rootHelpFunc(f *cmdutil.Factory, command *cobra.Command, _ []string) { return } - namePadding := 12 - type helpEntry struct { Title string Body string @@ -135,6 +133,12 @@ func rootHelpFunc(f *cmdutil.Factory, command *cobra.Command, _ []string) { helpEntries = append(helpEntries, helpEntry{"ALIASES", strings.Join(BuildAliasList(command, command.Aliases), ", ") + "\n"}) } + // Statically calculated padding for non-extension commands, + // longest is `gh accessibility` with 13 characters + 1 space. + // + // Should consider novel way to calculate this in the future [AF] + namePadding := 14 + for _, g := range GroupedCommands(command) { var names []string for _, c := range g.Commands { @@ -148,6 +152,9 @@ func rootHelpFunc(f *cmdutil.Factory, command *cobra.Command, _ []string) { if isRootCmd(command) { var helpTopics []string + if c := findCommand(command, "accessibility"); c != nil { + helpTopics = append(helpTopics, rpad(c.Name()+":", namePadding)+c.Short) + } if c := findCommand(command, "actions"); c != nil { helpTopics = append(helpTopics, rpad(c.Name()+":", namePadding)+c.Short) } diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index c0dad93ec..8cf30db1b 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/MakeNowJust/heredoc" + accessibilityCmd "github.com/cli/cli/v2/pkg/cmd/accessibility" actionsCmd "github.com/cli/cli/v2/pkg/cmd/actions" aliasCmd "github.com/cli/cli/v2/pkg/cmd/alias" "github.com/cli/cli/v2/pkg/cmd/alias/shared" @@ -122,6 +123,7 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) (*cobra.Command, // Child commands cmd.AddCommand(versionCmd.NewCmdVersion(f, version, buildDate)) + cmd.AddCommand(accessibilityCmd.NewCmdAccessibility(f)) cmd.AddCommand(actionsCmd.NewCmdActions(f)) cmd.AddCommand(aliasCmd.NewCmdAlias(f)) cmd.AddCommand(authCmd.NewCmdAuth(f)) From fb97b3efaabaf3a727beb6ed5f4adbf9e780f9ff Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 24 Apr 2025 18:41:14 +0200 Subject: [PATCH 23/51] Fix pr create when push.default tracking and no merge ref (#10863) * Fix pr create when push.default tracking and no merge ref * Update pkg/cmd/pr/shared/find_refs_resolution.go --------- Co-authored-by: Tyler McGoffin --- ...h-default-upstream-no-merge-ref-fork.txtar | 50 +++++++++++++++++++ ...e-push-default-upstream-no-merge-ref.txtar | 33 ++++++++++++ pkg/cmd/pr/shared/find_refs_resolution.go | 8 +-- .../pr/shared/find_refs_resolution_test.go | 8 +-- 4 files changed, 92 insertions(+), 7 deletions(-) create mode 100644 acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref-fork.txtar create mode 100644 acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref.txtar diff --git a/acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref-fork.txtar b/acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref-fork.txtar new file mode 100644 index 000000000..0974f9225 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref-fork.txtar @@ -0,0 +1,50 @@ +skip 'it creates a fork owned by the user running the test' +skip 'this never worked, but could be fixed if we fixed show-refs' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Configure push.default so that it should use the merge ref +exec git config push.default upstream + +# But prepare a branch that doesn't have a tracking merge ref +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref.txtar b/acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref.txtar new file mode 100644 index 000000000..90c5cde50 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-push-default-upstream-no-merge-ref.txtar @@ -0,0 +1,33 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Clone the repo +exec gh repo clone ${ORG}/${REPO} +cd ${REPO} + +# Configure push.default so that it should use the merge ref +exec git config push.default upstream + +# But prepare a branch that doesn't have a tracking merge ref +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 diff --git a/pkg/cmd/pr/shared/find_refs_resolution.go b/pkg/cmd/pr/shared/find_refs_resolution.go index 833075af8..e4e51bab8 100644 --- a/pkg/cmd/pr/shared/find_refs_resolution.go +++ b/pkg/cmd/pr/shared/find_refs_resolution.go @@ -333,12 +333,12 @@ func tryDetermineDefaultPushTarget(gitClient GitConfigClient, localBranchName st } // We assume the PR's branch name is the same as whatever was provided, unless the user has specified - // push.default = upstream or tracking, then we use the branch name from the merge ref. + // push.default = upstream or tracking, then we use the branch name from the merge ref if it exists. Otherwise, we fall back to the local branch name remoteBranch := localBranchName if pushDefault == git.PushDefaultUpstream || pushDefault == git.PushDefaultTracking { - remoteBranch = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") - if remoteBranch == "" { - return defaultPushTarget{}, fmt.Errorf("could not determine remote branch name") + mergeRef := strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") + if mergeRef != "" { + remoteBranch = mergeRef } } diff --git a/pkg/cmd/pr/shared/find_refs_resolution_test.go b/pkg/cmd/pr/shared/find_refs_resolution_test.go index 8cbb62146..d2393bf10 100644 --- a/pkg/cmd/pr/shared/find_refs_resolution_test.go +++ b/pkg/cmd/pr/shared/find_refs_resolution_test.go @@ -462,7 +462,7 @@ func TestTryDetermineDefaultPRHead(t *testing.T) { }) } - t.Run("but if the merge ref is empty, error", func(t *testing.T) { + t.Run("but if the merge ref is empty, use the provided branch name", func(t *testing.T) { t.Parallel() repoResolvedFromPushRemoteClient := stubGitConfigClient{ @@ -474,12 +474,14 @@ func TestTryDetermineDefaultPRHead(t *testing.T) { pushDefaultFn: stubPushDefault(git.PushDefaultUpstream, nil), } - _, err := TryDetermineDefaultPRHead( + defaultPRHead, err := TryDetermineDefaultPRHead( repoResolvedFromPushRemoteClient, stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), "feature-branch", ) - require.Error(t, err) + require.NoError(t, err) + + require.Equal(t, "feature-branch", defaultPRHead.BranchName) }) }) From abd98bd727521286e0d0179a1c4818d38eb74ab7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:00:36 +0000 Subject: [PATCH 24/51] chore(deps): bump github.com/cpuguy83/go-md2man/v2 from 2.0.6 to 2.0.7 Bumps [github.com/cpuguy83/go-md2man/v2](https://github.com/cpuguy83/go-md2man) from 2.0.6 to 2.0.7. - [Release notes](https://github.com/cpuguy83/go-md2man/releases) - [Commits](https://github.com/cpuguy83/go-md2man/compare/v2.0.6...v2.0.7) --- updated-dependencies: - dependency-name: github.com/cpuguy83/go-md2man/v2 dependency-version: 2.0.7 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 31b07f2cf..1ea50709d 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24 github.com/cli/oauth v1.1.1 github.com/cli/safeexec v1.0.1 - github.com/cpuguy83/go-md2man/v2 v2.0.6 + github.com/cpuguy83/go-md2man/v2 v2.0.7 github.com/creack/pty v1.1.24 github.com/digitorus/timestamp v0.0.0-20231217203849-220c5c2851b7 github.com/distribution/reference v0.6.0 diff --git a/go.sum b/go.sum index b312bcf6c..0203b8bfc 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,9 @@ github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb h1:EDmT6Q9Zs+SbUo github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb/go.mod h1:ZjrT6AXHbDs86ZSdt/osfBi5qfexBrKUdONk989Wnk4= github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= -github.com/cpuguy83/go-md2man/v2 v2.0.6 h1:XJtiaUW6dEEqVuZiMTn1ldk455QWwEIsMIJlo5vtkx0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= +github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= From 97a3b70599dcba4aae1c24fc36796362505d39ed Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Sat, 26 Apr 2025 12:57:10 -0400 Subject: [PATCH 25/51] Update to huh@0.7.0, echo mode changes This commit is the initial change around updating to huh@0.7.0; pre-testing changes. --- go.mod | 2 +- go.sum | 12 ++++++++++-- internal/prompter/prompter.go | 9 ++++++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 31b07f2cf..3562f24a6 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/briandowns/spinner v1.18.1 github.com/cenkalti/backoff/v4 v4.3.0 github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3 - github.com/charmbracelet/huh v0.6.1-0.20250409210615-c5906631cbb5 + github.com/charmbracelet/huh v0.7.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc github.com/cli/go-gh/v2 v2.12.0 github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24 diff --git a/go.sum b/go.sum index b312bcf6c..2ac25c2f8 100644 --- a/go.sum +++ b/go.sum @@ -110,20 +110,28 @@ github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4p github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3 h1:hx6E25SvI2WiZdt/gxINcYBnHD7PE2Vr9auqwg5B05g= github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3/go.mod h1:ihVqv4/YOY5Fweu1cxajuQrwJFh3zU4Ukb4mHVNjq3s= -github.com/charmbracelet/huh v0.6.1-0.20250409210615-c5906631cbb5 h1:uOnMxWghHfEYm2DPMeIHHAEirV/TduBVC9ZRXGcX9Q8= -github.com/charmbracelet/huh v0.6.1-0.20250409210615-c5906631cbb5/go.mod h1:xl27E/xNaX3WwdkqpvBwjJcGWhupkU52CWLC5hReBTw= +github.com/charmbracelet/huh v0.7.0 h1:W8S1uyGETgj9Tuda3/JdVkc3x7DBLZYPZc4c+/rnRdc= +github.com/charmbracelet/huh v0.7.0/go.mod h1:UGC3DZHlgOKHvHC07a5vHag41zzhpPFj34U92sOmyuk= github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc h1:nFRtCfZu/zkltd2lsLUPlVNv3ej/Atod9hcdbRZtlys= github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA= github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE= github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q= github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/conpty v0.1.0 h1:4zc8KaIcbiL4mghEON8D72agYtSeIgq8FSThSPQIb+U= +github.com/charmbracelet/x/conpty v0.1.0/go.mod h1:rMFsDJoDwVmiYM10aD4bH2XiRgwI7NYJtQgl5yskjEQ= +github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA= +github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4= github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= +github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= +github.com/charmbracelet/x/xpty v0.1.2 h1:Pqmu4TEJ8KeA9uSkISKMU3f+C1F6OGBn8ABuGlqCbtI= +github.com/charmbracelet/x/xpty v0.1.2/go.mod h1:XK2Z0id5rtLWcpeNiMYBccNNBrP2IJnzHI0Lq13Xzq4= github.com/cli/browser v1.0.0/go.mod h1:IEWkHYbLjkhtjwwWlwTHW2lGxeS5gezEQBMLTwDHf5Q= github.com/cli/browser v1.3.0 h1:LejqCrpWr+1pRqmEPDGnTZOjsMe7sehifLynZJuqJpo= github.com/cli/browser v1.3.0/go.mod h1:HH8s+fOAxjhQoBUAsKuPCbqUuxZDhQ2/aD+SzsEfBTk= diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index 2a4328366..d56374665 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -137,10 +137,12 @@ func (p *accessiblePrompter) Input(prompt, defaultValue string) (string, error) func (p *accessiblePrompter) Password(prompt string) (string, error) { var result string - // EchoMode(huh.EchoModePassword) doesn't have any effect in accessible mode. + // EchoModeNone and EchoModePassword both result in disabling echo mode + // as password masking is outside of VT100 spec. form := p.newForm( huh.NewGroup( huh.NewInput(). + EchoMode(huh.EchoModeNone). Title(prompt). Value(&result), ), @@ -171,9 +173,12 @@ func (p *accessiblePrompter) Confirm(prompt string, defaultValue bool) (bool, er func (p *accessiblePrompter) AuthToken() (string, error) { var result string + // EchoModeNone and EchoModePassword both result in disabling echo mode + // as password masking is outside of VT100 spec. form := p.newForm( huh.NewGroup( huh.NewInput(). + EchoMode(huh.EchoModeNone). Title("Paste your authentication token:"). // Note: if this validation fails, the prompt loops. Validate(func(input string) error { @@ -183,8 +188,6 @@ func (p *accessiblePrompter) AuthToken() (string, error) { return nil }). Value(&result), - // This doesn't have any effect in accessible mode. - // EchoMode(huh.EchoModePassword), ), ) From 519926b7cf7458df6e12d9f280ae6c2072796e22 Mon Sep 17 00:00:00 2001 From: Antonio Consuegra Date: Mon, 28 Apr 2025 13:54:09 +0200 Subject: [PATCH 26/51] Fix expected error output of TestRepo/repo-set-default --- acceptance/testdata/repo/repo-set-default.txtar | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acceptance/testdata/repo/repo-set-default.txtar b/acceptance/testdata/repo/repo-set-default.txtar index 4f7fa3273..de4eda11f 100644 --- a/acceptance/testdata/repo/repo-set-default.txtar +++ b/acceptance/testdata/repo/repo-set-default.txtar @@ -7,7 +7,7 @@ defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING # Ensure that no default is set cd $SCRIPT_NAME-$RANDOM_STRING exec gh repo set-default --view -stderr 'no default repository has been set; use `gh repo set-default` to select one' +stderr 'No default remote repository has been set. To learn more about the default repository, run: gh repo set-default --help' # Set the default exec gh repo set-default $ORG/$SCRIPT_NAME-$RANDOM_STRING From a53b6c074ce66b8df7dd7abddcc713abf15efa1b Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Mon, 28 Apr 2025 08:55:47 -0400 Subject: [PATCH 27/51] Assert password and auth token not displayed This commit expands existing tests (thanks to @babakks) to assert whether the echo mode is actually disabled for password and auth token prompts. --- internal/prompter/accessible_prompter_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 619eb14f1..63f253331 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -134,6 +134,11 @@ func TestAccessiblePrompter(t *testing.T) { passwordValue, err := p.Password("Enter password") require.NoError(t, err) require.Equal(t, dummyPassword, passwordValue) + + // Ensure the dummy password is not printed to the screen, + // asserting that echo mode is disabled without OS-level tests. + _, err = console.ExpectString(" \r\n\r\n") + require.NoError(t, err) }) t.Run("Confirm", func(t *testing.T) { @@ -192,6 +197,11 @@ func TestAccessiblePrompter(t *testing.T) { authValue, err := p.AuthToken() require.NoError(t, err) require.Equal(t, dummyAuthToken, authValue) + + // Ensure the dummy password is not printed to the screen, + // asserting that echo mode is disabled without OS-level tests. + _, err = console.ExpectString(" \r\n\r\n") + require.NoError(t, err) }) t.Run("AuthToken - blank input returns error", func(t *testing.T) { @@ -220,6 +230,11 @@ func TestAccessiblePrompter(t *testing.T) { authValue, err := p.AuthToken() require.NoError(t, err) require.Equal(t, dummyAuthTokenForAfterFailure, authValue) + + // Ensure the dummy password is not printed to the screen, + // asserting that echo mode is disabled without OS-level tests. + _, err = console.ExpectString(" \r\n\r\n") + require.NoError(t, err) }) t.Run("ConfirmDeletion", func(t *testing.T) { From 9fa00c350bb4a09dca6ed76d607ed8baef3d9d2e Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Mon, 28 Apr 2025 10:17:23 -0400 Subject: [PATCH 28/51] Update accessible tests based on huh@0.7.0 changes --- internal/prompter/accessible_prompter_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 63f253331..2211d7720 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -38,7 +38,7 @@ func TestAccessiblePrompter(t *testing.T) { go func() { // Wait for prompt to appear - _, err := console.ExpectString("Choose:") + _, err := console.ExpectString("Input a number between 1 and 3:") require.NoError(t, err) // Select option 1 @@ -57,7 +57,7 @@ func TestAccessiblePrompter(t *testing.T) { go func() { // Wait for prompt to appear - _, err := console.ExpectString("Select a number") + _, err := console.ExpectString("Input a number between 0 and 3:") require.NoError(t, err) // Select options 1 and 2 @@ -340,7 +340,7 @@ func TestAccessiblePrompter(t *testing.T) { require.NoError(t, err) // Expect a notice to enter something valid since blank is disallowed. - _, err = console.ExpectString("invalid input. please try again") + _, err = console.ExpectString("Invalid: must be between 1 and 1") require.NoError(t, err) // Send a 1 to select to open the editor. This will immediately exit @@ -367,7 +367,7 @@ func TestAccessiblePrompter(t *testing.T) { require.NoError(t, err) // Expect a notice to enter something valid since blank is disallowed. - _, err = console.ExpectString("invalid input. please try again") + _, err = console.ExpectString("Invalid: must be between 1 and 1") require.NoError(t, err) // Send a 1 to select to open the editor since skip is invalid and From 2d66877d6c447ffb20a602c338e6aab13c916035 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Mon, 28 Apr 2025 11:15:28 -0400 Subject: [PATCH 29/51] Update internal/prompter/accessible_prompter_test.go --- internal/prompter/accessible_prompter_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 2211d7720..3c8420cde 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -136,7 +136,7 @@ func TestAccessiblePrompter(t *testing.T) { require.Equal(t, dummyPassword, passwordValue) // Ensure the dummy password is not printed to the screen, - // asserting that echo mode is disabled without OS-level tests. + // asserting that echo mode is disabled. _, err = console.ExpectString(" \r\n\r\n") require.NoError(t, err) }) From df0aedbe3c0e5e474f7b5238354cf2ba22eecd3a Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Mon, 28 Apr 2025 11:16:35 -0400 Subject: [PATCH 30/51] Update internal/prompter/prompter.go --- internal/prompter/prompter.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index d56374665..c3281efbf 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -137,8 +137,8 @@ func (p *accessiblePrompter) Input(prompt, defaultValue string) (string, error) func (p *accessiblePrompter) Password(prompt string) (string, error) { var result string - // EchoModeNone and EchoModePassword both result in disabling echo mode - // as password masking is outside of VT100 spec. + // EchoModePassword is not used as password masking is unsupported in huh. + // EchoModeNone and EchoModePassword have the same effect of hiding user input. form := p.newForm( huh.NewGroup( huh.NewInput(). From 88d52ebf97bcbfe48a064206709e0adcb0c92922 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Mon, 28 Apr 2025 11:20:17 -0400 Subject: [PATCH 31/51] Fix other disabled echo mode comments --- internal/prompter/accessible_prompter_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 3c8420cde..c95d379e3 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -199,7 +199,7 @@ func TestAccessiblePrompter(t *testing.T) { require.Equal(t, dummyAuthToken, authValue) // Ensure the dummy password is not printed to the screen, - // asserting that echo mode is disabled without OS-level tests. + // asserting that echo mode is disabled. _, err = console.ExpectString(" \r\n\r\n") require.NoError(t, err) }) @@ -232,7 +232,7 @@ func TestAccessiblePrompter(t *testing.T) { require.Equal(t, dummyAuthTokenForAfterFailure, authValue) // Ensure the dummy password is not printed to the screen, - // asserting that echo mode is disabled without OS-level tests. + // asserting that echo mode is disabled. _, err = console.ExpectString(" \r\n\r\n") require.NoError(t, err) }) From 9bb89de87c7fdeca18ccfc9b338e110c1ba676a9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Apr 2025 15:44:19 +0000 Subject: [PATCH 32/51] chore(deps): bump actions/attest-build-provenance from 2.2.2 to 2.3.0 Bumps [actions/attest-build-provenance](https://github.com/actions/attest-build-provenance) from 2.2.2 to 2.3.0. - [Release notes](https://github.com/actions/attest-build-provenance/releases) - [Changelog](https://github.com/actions/attest-build-provenance/blob/main/RELEASE.md) - [Commits](https://github.com/actions/attest-build-provenance/compare/bd77c077858b8d561b7a36cbe48ef4cc642ca39d...db473fddc028af60658334401dc6fa3ffd8669fd) --- updated-dependencies: - dependency-name: actions/attest-build-provenance dependency-version: 2.3.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/deployment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index a7b03f40d..850cc19b7 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -309,7 +309,7 @@ jobs: rpmsign --addsign dist/*.rpm - name: Attest release artifacts if: inputs.environment == 'production' - uses: actions/attest-build-provenance@bd77c077858b8d561b7a36cbe48ef4cc642ca39d # v2.2.2 + uses: actions/attest-build-provenance@db473fddc028af60658334401dc6fa3ffd8669fd # v2.3.0 with: subject-path: "dist/gh_*" - name: Run createrepo From d7e2468286db92630a0393386d717aec3c46f0fb Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Mon, 28 Apr 2025 15:01:15 -0400 Subject: [PATCH 33/51] Update a11y text based on draft feedback --- pkg/cmd/accessibility/accessibility.go | 89 +++++++++++++------------- pkg/cmd/root/help.go | 1 + 2 files changed, 44 insertions(+), 46 deletions(-) diff --git a/pkg/cmd/accessibility/accessibility.go b/pkg/cmd/accessibility/accessibility.go index 4992d488b..1d4a8009f 100644 --- a/pkg/cmd/accessibility/accessibility.go +++ b/pkg/cmd/accessibility/accessibility.go @@ -12,7 +12,7 @@ import ( ) const ( - communityURL = "https://github.com/orgs/community/discussions/categories/accessibility" + feedbackURL = "https://accessibility.github.com/feedback" ) type AccessibilityOptions struct { @@ -30,27 +30,27 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { cmd := &cobra.Command{ Use: "accessibility", Aliases: []string{"a11y"}, - Short: "Learn about GitHub CLI accessibility experience", + Short: "Learn about GitHub CLI accessibility experiences", Long: longDescription(opts.IO), Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { if opts.Web { if opts.IO.IsStdoutTTY() { - fmt.Fprintf(opts.IO.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(communityURL)) + fmt.Fprintf(opts.IO.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(feedbackURL)) } - return opts.Browser.Browse(communityURL) + return opts.Browser.Browse(feedbackURL) } return cmd.Help() }, Example: heredoc.Doc(` - # Open the GitHub Community Accessibility discussions in your browser + # Open the GitHub Accessibility site in your browser $ gh accessibility --web # Display color using customizable, 4-bit accessible colors $ gh config set accessible_colors enabled - # Display issue and pull request labels using RGB hex color codes in terminals that support 24-bit truecolor + # Display issue and pull request labels using RGB hex color codes in terminals that support 24-bit true color $ gh config set color_labels enabled # Use input prompts without redrawing the screen @@ -61,7 +61,7 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { `), } - cmd.Flags().BoolVarP(&opts.Web, "web", "w", false, "Open the GitHub Community Accessibility discussions in the browser") + cmd.Flags().BoolVarP(&opts.Web, "web", "w", false, "Open the GitHub Accessibility site in your browser") cmdutil.DisableAuthCheck(cmd) return cmd @@ -69,10 +69,11 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { func longDescription(io *iostreams.IOStreams) string { cs := io.ColorScheme() - title := cs.Bold("LEARN ABOUT GITHUB CLI ACCESSIBILITY EFFORTS") - color := cs.Bold("CUSTOMIZABLE AND CONTRASTING COLORS") - prompter := cs.Bold("NON-INTERACTIVE USER INPUT PROMPTING") - spinner := cs.Bold("TEXT-BASED SPINNERS") + title := cs.Bold("Learn about GitHub CLI accessibility experiences") + color := cs.Bold("Customizable and contrasting colors") + prompter := cs.Bold("Non-interactive user input prompting") + spinner := cs.Bold("Text-based spinners") + feedback := cs.Bold("Join the conversation") return heredoc.Docf(` %[2]s @@ -81,70 +82,66 @@ func longDescription(io *iostreams.IOStreams) string { community and be empowered to contribute to the future of global software development with everything GitHub has to offer including the GitHub CLI. - We invite you to join us in improving GitHub CLI accessibility by sharing your - feedback and ideas in the GitHub Community Accessibility discussions: %[3]s + Text interfaces often use color for various purposes, but insufficient contrast + or customizability can leave some users unable to benefit. - %[4]s - - Color is a common approach to enhance user experiences, however users can find - themselves with a worse experience due to insufficient contrast or - customizability. - - To create an accessible experience, CLIs should use color palettes based on - terminal background appearance and limit colors to 4-bit ANSI color palettes, - which users can customize within terminal preferences. + To create a more accessible experience, the GitHub CLI will use color palettes + based on terminal background appearance and limit colors to 4-bit ANSI color + palettes, which users can customize within terminal preferences. With this new experience, the GitHub CLI provides multiple options to address color usage: - 1. The GitHub CLI will use 4-bit color palette for increased color contrast based on - dark and light backgrounds including rendering markdown based on GitHub Primer. + 1. The GitHub CLI will use 4-bit color palette for increased color contrast based + on dark and light backgrounds including rendering Markdown based on the + GitHub Primer design system. To enable this experience, use one of the following methods: - Run %[1]sgh config set accessible_colors enabled%[1]s - Set %[1]sGH_ACCESSIBLE_COLORS=enabled%[1]s environment variable 2. The GitHub CLI will display issue and pull request labels' custom RGB colors - in terminals with truecolor support. + in terminals with true color support. To enable this experience, use one of the following methods: - Run %[1]sgh config set color_labels enabled%[1]s - Set %[1]sGH_COLOR_LABELS=enabled%[1]s environment variable + %[4]s - %[5]s + Interactive text user interfaces manipulate the terminal cursor to redraw parts + of the screen, which can be difficult for speech synthesizers or braille displays + to accurately detect and read. - Interactive text user interfaces are an advanced approach to enhance user - experiences, which manipulate the terminal cursor to redraw parts of the screen. - However, this can be difficult for speech synthesizers or braille displays to - accurately detect and read. - - To create an accessible experience, CLIs should give users the ability to disable - this interactivity while providing a similar experience. - - With this new experience, the GitHub CLI will use non-interactive prompts for - user input. + To create a more accessible experience, the GitHub CLI gives users the ability to + disable this interactivity while providing a similar experience using + non-interactive prompts for user input. To enable this experience, use one of the following methods: - Run %[1]sgh config set accessible_prompter enabled%[1]s - Set %[1]sGH_ACCESSIBLE_PROMPTER=enabled%[1]s environment variable + %[5]s - %[6]s + Motion-based spinners communicate in-progress activity by manipulating the + terminal cursor to create a spinning effect, which can be difficult for users + with motion sensitivity or miscommunicate information to speech synthesizers. - Motion-based spinners are a common approach to communicate activity, which - manipulate the terminal cursor to create a spinning effect. However, this can be - difficult for users with motion sensitivity as well as speech synthesizers. - - To create an accessible experience, CLIs should give users the ability to disable - this interactivity while providing a similar experience. - - With this new experience, the GitHub CLI will use text-based progress indicators. + To create a more accessible experience, the GitHub CLI gives users the ability to + disable this interactivity while providing a similar experience using text-based + progress indicators. To enable this experience, use one of the following methods: - Run %[1]sgh config set spinner disabled%[1]s - Set %[1]sGH_SPINNER_DISABLED=yes%[1]s environment variable - `, "`", title, communityURL, color, prompter, spinner) + + %[6]s + + We invite you to join us in improving GitHub CLI accessibility by sharing your + feedback and ideas through GitHub Accessibility feedback channels: + + %[7]s + `, "`", title, color, prompter, spinner, feedback, feedbackURL) } diff --git a/pkg/cmd/root/help.go b/pkg/cmd/root/help.go index ec6499f21..2676cdd15 100644 --- a/pkg/cmd/root/help.go +++ b/pkg/cmd/root/help.go @@ -190,6 +190,7 @@ func rootHelpFunc(f *cmdutil.Factory, command *cobra.Command, _ []string) { Use %[1]sgh --help%[1]s for more information about a command. Read the manual at https://cli.github.com/manual Learn about exit codes using %[1]sgh help exit-codes%[1]s + Learn about accessibility experiences using %[1]sgh help accessibility%[1]s `, "`")}) out := f.IOStreams.Out From 9ed733fa5e751be1196f133b086e8981a835ee31 Mon Sep 17 00:00:00 2001 From: Azeem Sajid Date: Tue, 29 Apr 2025 15:48:20 +0500 Subject: [PATCH 34/51] Add `closingIssuesReferences` JSON field to `pr view` (#10544) * [gh pr view] Support `closingIssuesReferences` JSON field * Support pagination * Support pagination * Fix typo * Add more fields --- api/export_pr.go | 19 +++++++++++ api/export_pr_test.go | 64 ++++++++++++++++++++++++++++++++++++ api/queries_pr.go | 22 +++++++++++++ api/query_builder.go | 22 +++++++++++++ pkg/cmd/pr/shared/finder.go | 44 +++++++++++++++++++++++++ pkg/cmd/pr/view/view_test.go | 1 + 6 files changed, 172 insertions(+) diff --git a/api/export_pr.go b/api/export_pr.go index bb3310811..7ae1a4ff4 100644 --- a/api/export_pr.go +++ b/api/export_pr.go @@ -139,6 +139,25 @@ func (pr *PullRequest) ExportData(fields []string) map[string]interface{} { } } data[f] = &requests + case "closingIssuesReferences": + items := make([]map[string]interface{}, 0, len(pr.ClosingIssuesReferences.Nodes)) + for _, n := range pr.ClosingIssuesReferences.Nodes { + items = append(items, map[string]interface{}{ + + "id": n.ID, + "number": n.Number, + "url": n.URL, + "repository": map[string]interface{}{ + "id": n.Repository.ID, + "name": n.Repository.Name, + "owner": map[string]interface{}{ + "id": n.Repository.Owner.ID, + "login": n.Repository.Owner.Login, + }, + }, + }) + } + data[f] = items default: sf := fieldByName(v, f) data[f] = sf.Interface() diff --git a/api/export_pr_test.go b/api/export_pr_test.go index b7f4dcddb..09a1dffe8 100644 --- a/api/export_pr_test.go +++ b/api/export_pr_test.go @@ -245,6 +245,70 @@ func TestPullRequest_ExportData(t *testing.T) { } `), }, + { + name: "linked issues", + fields: []string{"closingIssuesReferences"}, + inputJSON: heredoc.Doc(` + { "closingIssuesReferences": { "nodes": [ + { + "id": "I_123", + "number": 123, + "url": "https://github.com/cli/cli/issues/123", + "repository": { + "id": "R_123", + "name": "cli", + "owner": { + "id": "O_123", + "login": "cli" + } + } + }, + { + "id": "I_456", + "number": 456, + "url": "https://github.com/cli/cli/issues/456", + "repository": { + "id": "R_456", + "name": "cli", + "owner": { + "id": "O_456", + "login": "cli" + } + } + } + ] } } + `), + outputJSON: heredoc.Doc(` + { "closingIssuesReferences": [ + { + "id": "I_123", + "number": 123, + "repository": { + "id": "R_123", + "name": "cli", + "owner": { + "id": "O_123", + "login": "cli" + } + }, + "url": "https://github.com/cli/cli/issues/123" + }, + { + "id": "I_456", + "number": 456, + "repository": { + "id": "R_456", + "name": "cli", + "owner": { + "id": "O_456", + "login": "cli" + } + }, + "url": "https://github.com/cli/cli/issues/456" + } + ] } + `), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/api/queries_pr.go b/api/queries_pr.go index aa493b5e9..5b941bb42 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -93,6 +93,8 @@ type PullRequest struct { Reviews PullRequestReviews LatestReviews PullRequestReviews ReviewRequests ReviewRequests + + ClosingIssuesReferences ClosingIssuesReferences } type StatusCheckRollupNode struct { @@ -107,6 +109,26 @@ type CommitStatusCheckRollup struct { Contexts CheckContexts } +type ClosingIssuesReferences struct { + Nodes []struct { + ID string + Number int + URL string + Repository struct { + ID string + Name string + Owner struct { + ID string + Login string + } + } + } + PageInfo struct { + HasNextPage bool + EndCursor string + } +} + // https://docs.github.com/en/graphql/reference/enums#checkrunstate type CheckRunState string diff --git a/api/query_builder.go b/api/query_builder.go index 2112367e3..4c45da3c1 100644 --- a/api/query_builder.go +++ b/api/query_builder.go @@ -132,6 +132,25 @@ var prCommits = shortenQuery(` } `) +var prClosingIssuesReferences = shortenQuery(` + closingIssuesReferences(first: 100) { + nodes { + id, + number, + url, + repository { + id, + name, + owner { + id, + login + } + } + } + pageInfo{hasNextPage,endCursor} + } +`) + var autoMergeRequest = shortenQuery(` autoMergeRequest { authorEmail, @@ -287,6 +306,7 @@ var PullRequestFields = append(sharedIssuePRFields, "baseRefName", "baseRefOid", "changedFiles", + "closingIssuesReferences", "commits", "deletions", "files", @@ -366,6 +386,8 @@ func IssueGraphQL(fields []string) string { q = append(q, StatusCheckRollupGraphQLWithoutCountByState("")) case "statusCheckRollupWithCountByState": // pseudo-field q = append(q, StatusCheckRollupGraphQLWithCountByState()) + case "closingIssuesReferences": + q = append(q, prClosingIssuesReferences) default: q = append(q, field) } diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 6d36ef816..e6bb7d66a 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -239,6 +239,11 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return preloadPrComments(httpClient, f.baseRefRepo, pr) }) } + if fields.Contains("closingIssuesReferences") { + g.Go(func() error { + return preloadPrClosingIssuesReferences(httpClient, f.baseRefRepo, pr) + }) + } if fields.Contains("statusCheckRollup") { g.Go(func() error { return preloadPrChecks(httpClient, f.baseRefRepo, pr) @@ -452,6 +457,45 @@ func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullR return nil } +func preloadPrClosingIssuesReferences(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { + if !pr.ClosingIssuesReferences.PageInfo.HasNextPage { + return nil + } + + type response struct { + Node struct { + PullRequest struct { + ClosingIssuesReferences api.ClosingIssuesReferences `graphql:"closingIssuesReferences(first: 100, after: $endCursor)"` + } `graphql:"...on PullRequest"` + } `graphql:"node(id: $id)"` + } + + variables := map[string]interface{}{ + "id": githubv4.ID(pr.ID), + "endCursor": githubv4.String(pr.ClosingIssuesReferences.PageInfo.EndCursor), + } + + gql := api.NewClientFromHTTP(client) + + for { + var query response + err := gql.Query(repo.RepoHost(), "closingIssuesReferences", &query, variables) + if err != nil { + return err + } + + pr.ClosingIssuesReferences.Nodes = append(pr.ClosingIssuesReferences.Nodes, query.Node.PullRequest.ClosingIssuesReferences.Nodes...) + + if !query.Node.PullRequest.ClosingIssuesReferences.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Node.PullRequest.ClosingIssuesReferences.PageInfo.EndCursor) + } + + pr.ClosingIssuesReferences.PageInfo.HasNextPage = false + return nil +} + func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { if len(pr.StatusCheckRollup.Nodes) == 0 { return nil diff --git a/pkg/cmd/pr/view/view_test.go b/pkg/cmd/pr/view/view_test.go index e7f572c76..2cd4066b8 100644 --- a/pkg/cmd/pr/view/view_test.go +++ b/pkg/cmd/pr/view/view_test.go @@ -37,6 +37,7 @@ func TestJSONFields(t *testing.T) { "changedFiles", "closed", "closedAt", + "closingIssuesReferences", "comments", "commits", "createdAt", From 692bdaf5784ecc326deb78089a077b8e2c4ddf07 Mon Sep 17 00:00:00 2001 From: Barak Amar Date: Tue, 29 Apr 2025 14:32:51 +0300 Subject: [PATCH 35/51] Apply code review changes --- pkg/cmd/pr/shared/finder.go | 10 ++++++++-- pkg/cmd/pr/shared/finder_test.go | 4 +--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 9e92c0692..a87d6790f 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -212,7 +212,8 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err } var pr *api.PullRequest - if f.prNumber > 0 || f.branchName == "" { + if f.prNumber > 0 { + // If we have a PR number, let's look it up if numberFieldOnly { // avoid hitting the API if we already have all the information return &api.PullRequest{Number: f.prNumber}, f.baseRefRepo, nil @@ -221,11 +222,16 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err if err != nil { return pr, f.baseRefRepo, err } - } else { + } else if prRefs.BaseRepo() != nil && f.branchName != "" { + // No PR number, but we have a base repo and branch name. pr, err = findForRefs(httpClient, prRefs, opts.States, fields.ToSlice()) if err != nil { return pr, f.baseRefRepo, err } + } else { + // If we don't have a PR number or a base repo and branch name, + // we can't do anything + return nil, f.baseRefRepo, &NotFoundError{fmt.Errorf("no pull requests found")} } g, _ := errgroup.WithContext(context.Background()) diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 66fb900eb..abc754d1a 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -180,9 +180,7 @@ func TestFind(t *testing.T) { remotePushDefaultFn: stubRemotePushDefault("", nil), }, }, - httpStub: nil, - wantPR: 0, - wantRepo: "https://github.com/ORIGINOWNER/REPO", + wantErr: true, }, { name: "number with hash argument", From d8512a90666afc2e247506777b3319b2ddb820e8 Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Tue, 29 Apr 2025 16:35:04 -0600 Subject: [PATCH 36/51] fix(prompter): respect default MultiSelect a11y prompter --- internal/prompter/accessible_prompter_test.go | 21 +++++++++++++++++++ internal/prompter/prompter.go | 9 ++++++++ 2 files changed, 30 insertions(+) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 619eb14f1..a7326752d 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -76,6 +76,27 @@ func TestAccessiblePrompter(t *testing.T) { assert.Equal(t, []int{0, 1}, multiSelectValue) }) + t.Run("MultiSelect - default values are respected by being pre-selected", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + + go func() { + // Wait for prompt to appear + _, err := console.ExpectString("Select a number") + require.NoError(t, err) + + // Don't select anything because the default should be selected. + + // This confirms selections + _, err = console.SendLine("0") + require.NoError(t, err) + }() + + multiSelectValue, err := p.MultiSelect("Select a number", []string{"2"}, []string{"1", "2", "3"}) + require.NoError(t, err) + assert.Equal(t, []int{1}, multiSelectValue) + }) + t.Run("Input", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index 2a4328366..322270086 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -2,6 +2,7 @@ package prompter import ( "fmt" + "slices" "strings" "github.com/AlecAivazis/survey/v2" @@ -100,6 +101,14 @@ func (p *accessiblePrompter) MultiSelect(prompt string, defaults []string, optio var result []int formOptions := make([]huh.Option[int], len(options)) for i, o := range options { + // If this option is in the defaults slice, + // let's add it's index to the result slice and huh + // will treat it as a default selection. + // TODO: does an invalid default value constitute a panic? + if slices.Contains(defaults, o) { + result = append(result, i) + } + formOptions[i] = huh.NewOption(o, i) } From 00c930d50957c1bededbf6a40b243be4e5a71bab Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Wed, 30 Apr 2025 08:04:16 -0600 Subject: [PATCH 37/51] doc(prompter): small typo --- internal/prompter/prompter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index 2c5c221b0..1e4f5592a 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -102,7 +102,7 @@ func (p *accessiblePrompter) MultiSelect(prompt string, defaults []string, optio formOptions := make([]huh.Option[int], len(options)) for i, o := range options { // If this option is in the defaults slice, - // let's add it's index to the result slice and huh + // let's add its index to the result slice and huh // will treat it as a default selection. // TODO: does an invalid default value constitute a panic? if slices.Contains(defaults, o) { From 096106a3d703ce9c8177b6608b58f6338f5478f0 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Wed, 30 Apr 2025 14:20:16 -0400 Subject: [PATCH 38/51] Apply suggestions from code review Co-authored-by: Melissa Xie --- pkg/cmd/accessibility/accessibility.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pkg/cmd/accessibility/accessibility.go b/pkg/cmd/accessibility/accessibility.go index 1d4a8009f..f05bc7bcc 100644 --- a/pkg/cmd/accessibility/accessibility.go +++ b/pkg/cmd/accessibility/accessibility.go @@ -30,7 +30,7 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { cmd := &cobra.Command{ Use: "accessibility", Aliases: []string{"a11y"}, - Short: "Learn about GitHub CLI accessibility experiences", + Short: "Learn about GitHub CLI's accessibility experiences", Long: longDescription(opts.IO), Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { @@ -87,7 +87,7 @@ func longDescription(io *iostreams.IOStreams) string { Text interfaces often use color for various purposes, but insufficient contrast or customizability can leave some users unable to benefit. - To create a more accessible experience, the GitHub CLI will use color palettes + For a more accessible experience, the GitHub CLI can use color palettes based on terminal background appearance and limit colors to 4-bit ANSI color palettes, which users can customize within terminal preferences. @@ -115,8 +115,7 @@ func longDescription(io *iostreams.IOStreams) string { of the screen, which can be difficult for speech synthesizers or braille displays to accurately detect and read. - To create a more accessible experience, the GitHub CLI gives users the ability to - disable this interactivity while providing a similar experience using + For a more accessible experience, the GitHub CLI can provide a similar experience using non-interactive prompts for user input. To enable this experience, use one of the following methods: @@ -126,12 +125,11 @@ func longDescription(io *iostreams.IOStreams) string { %[5]s Motion-based spinners communicate in-progress activity by manipulating the - terminal cursor to create a spinning effect, which can be difficult for users + terminal cursor to create a spinning effect, which may cause discomfort to users with motion sensitivity or miscommunicate information to speech synthesizers. - To create a more accessible experience, the GitHub CLI gives users the ability to - disable this interactivity while providing a similar experience using text-based - progress indicators. + For a more accessible experience, this interactivity can be disabled in favor + of text-based progress indicators. To enable this experience, use one of the following methods: - Run %[1]sgh config set spinner disabled%[1]s From 2fd1a45a81ae6466dea5598ce9825f984a6fb770 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Wed, 30 Apr 2025 14:21:02 -0400 Subject: [PATCH 39/51] Update pkg/cmd/accessibility/accessibility.go --- pkg/cmd/accessibility/accessibility.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/cmd/accessibility/accessibility.go b/pkg/cmd/accessibility/accessibility.go index f05bc7bcc..d37631d25 100644 --- a/pkg/cmd/accessibility/accessibility.go +++ b/pkg/cmd/accessibility/accessibility.go @@ -69,7 +69,7 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { func longDescription(io *iostreams.IOStreams) string { cs := io.ColorScheme() - title := cs.Bold("Learn about GitHub CLI accessibility experiences") + title := cs.Bold("Learn about GitHub CLI's accessibility experiences") color := cs.Bold("Customizable and contrasting colors") prompter := cs.Bold("Non-interactive user input prompting") spinner := cs.Bold("Text-based spinners") From c20138d8442457c5c6f326fdbd13d73e47b06a32 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Wed, 30 Apr 2025 14:35:35 -0400 Subject: [PATCH 40/51] Update pkg/cmd/accessibility/accessibility.go --- pkg/cmd/accessibility/accessibility.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/cmd/accessibility/accessibility.go b/pkg/cmd/accessibility/accessibility.go index d37631d25..19fb813a4 100644 --- a/pkg/cmd/accessibility/accessibility.go +++ b/pkg/cmd/accessibility/accessibility.go @@ -50,9 +50,6 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { # Display color using customizable, 4-bit accessible colors $ gh config set accessible_colors enabled - # Display issue and pull request labels using RGB hex color codes in terminals that support 24-bit true color - $ gh config set color_labels enabled - # Use input prompts without redrawing the screen $ gh config set accessible_prompter enabled From 830335d9209d399ae1415ec927dd4d00ea1e0ab2 Mon Sep 17 00:00:00 2001 From: Andy Feller Date: Wed, 30 Apr 2025 15:05:07 -0400 Subject: [PATCH 41/51] PR feedback --- pkg/cmd/accessibility/accessibility.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/cmd/accessibility/accessibility.go b/pkg/cmd/accessibility/accessibility.go index 19fb813a4..c5de6c1a4 100644 --- a/pkg/cmd/accessibility/accessibility.go +++ b/pkg/cmd/accessibility/accessibility.go @@ -12,7 +12,7 @@ import ( ) const ( - feedbackURL = "https://accessibility.github.com/feedback" + webURL = "https://accessibility.github.com/conformance/cli/" ) type AccessibilityOptions struct { @@ -36,9 +36,9 @@ func NewCmdAccessibility(f *cmdutil.Factory) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { if opts.Web { if opts.IO.IsStdoutTTY() { - fmt.Fprintf(opts.IO.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(feedbackURL)) + fmt.Fprintf(opts.IO.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(webURL)) } - return opts.Browser.Browse(feedbackURL) + return opts.Browser.Browse(webURL) } return cmd.Help() @@ -125,7 +125,7 @@ func longDescription(io *iostreams.IOStreams) string { terminal cursor to create a spinning effect, which may cause discomfort to users with motion sensitivity or miscommunicate information to speech synthesizers. - For a more accessible experience, this interactivity can be disabled in favor + For a more accessible experience, this interactivity can be disabled in favor of text-based progress indicators. To enable this experience, use one of the following methods: @@ -138,5 +138,5 @@ func longDescription(io *iostreams.IOStreams) string { feedback and ideas through GitHub Accessibility feedback channels: %[7]s - `, "`", title, color, prompter, spinner, feedback, feedbackURL) + `, "`", title, color, prompter, spinner, feedback, webURL) } From 0a1e7a1fdc68e2824605f0c4c6dd4dbcad448888 Mon Sep 17 00:00:00 2001 From: "Sinan Sonmez (Chaush)" <37421564+sinansonmez@users.noreply.github.com> Date: Thu, 1 May 2025 15:12:55 +0200 Subject: [PATCH 42/51] Add `--delete-last` option to `pr comment` and `issue comment` (#10596) * deletion for issues with confirmation flag * add handling for interaction case * finish implementation for issues * finish the implementation for issues * finalize the implementation for PR * fix missing --yes flag for PR * address PR comments related to feedbacks * improve CommentablePreRun for pre checks * improve confirmation prompt and truncate long comment body * address PR comments on tests * Truncate comment for confirmation prompt Signed-off-by: Babak K. Shandiz * Improve test case descriptions Signed-off-by: Babak K. Shandiz * Fix mock comment body Signed-off-by: Babak K. Shandiz * Remove irrelevant prompt stub Signed-off-by: Babak K. Shandiz * Use `opts.Interactive` as TTY indicator Signed-off-by: Babak K. Shandiz * Fix expected `Interactive` value Signed-off-by: Babak K. Shandiz * Polish `TestNewCmdComment` Signed-off-by: Babak K. Shandiz --------- Signed-off-by: Babak K. Shandiz Co-authored-by: Babak K. Shandiz --- api/queries_comments.go | 25 +++ pkg/cmd/issue/comment/comment.go | 7 +- pkg/cmd/issue/comment/comment_test.go | 218 ++++++++++++++++++++++++- pkg/cmd/pr/comment/comment.go | 7 +- pkg/cmd/pr/comment/comment_test.go | 220 +++++++++++++++++++++++++- pkg/cmd/pr/shared/commentable.go | 75 +++++++++ 6 files changed, 542 insertions(+), 10 deletions(-) diff --git a/api/queries_comments.go b/api/queries_comments.go index 5cc84a3e4..8af17fd2a 100644 --- a/api/queries_comments.go +++ b/api/queries_comments.go @@ -44,6 +44,10 @@ type CommentCreateInput struct { SubjectId string } +type CommentDeleteInput struct { + CommentId string +} + type CommentUpdateInput struct { Body string CommentId string @@ -99,6 +103,27 @@ func CommentUpdate(client *Client, repoHost string, params CommentUpdateInput) ( return mutation.UpdateIssueComment.IssueComment.URL, nil } +func CommentDelete(client *Client, repoHost string, params CommentDeleteInput) error { + var mutation struct { + DeleteIssueComment struct { + ClientMutationID string + } `graphql:"deleteIssueComment(input: $input)"` + } + + variables := map[string]interface{}{ + "input": githubv4.DeleteIssueCommentInput{ + ID: githubv4.ID(params.CommentId), + }, + } + + err := client.Mutate(repoHost, "CommentDelete", &mutation, variables) + if err != nil { + return err + } + + return nil +} + func (c Comment) Identifier() string { return c.ID } diff --git a/pkg/cmd/issue/comment/comment.go b/pkg/cmd/issue/comment/comment.go index 706ff791e..9b7791656 100644 --- a/pkg/cmd/issue/comment/comment.go +++ b/pkg/cmd/issue/comment/comment.go @@ -18,6 +18,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e InteractiveEditSurvey: prShared.CommentableInteractiveEditSurvey(f.Config, f.IOStreams), ConfirmSubmitSurvey: prShared.CommentableConfirmSubmitSurvey(f.Prompter), ConfirmCreateIfNoneSurvey: prShared.CommentableInteractiveCreateIfNoneSurvey(f.Prompter), + ConfirmDeleteLastComment: prShared.CommentableConfirmDeleteLastComment(f.Prompter), OpenInBrowser: f.Browser.Browse, } @@ -63,7 +64,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e } fields := []string{"id", "url"} - if opts.EditLast { + if opts.EditLast || opts.DeleteLast { fields = append(fields, "comments") } @@ -96,7 +97,9 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e cmd.Flags().StringVarP(&bodyFile, "body-file", "F", "", "Read body text from `file` (use \"-\" to read from standard input)") cmd.Flags().BoolP("editor", "e", false, "Skip prompts and open the text editor to write the body in") cmd.Flags().BoolP("web", "w", false, "Open the web browser to write the comment") - cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the same author") + cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLast, "delete-last", false, "Delete the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLastConfirmed, "yes", false, "Skip the delete confirmation prompt when --delete-last is provided") cmd.Flags().BoolVar(&opts.CreateIfNone, "create-if-none", false, "Create a new comment if no comments are found. Can be used only with --edit-last") return cmd diff --git a/pkg/cmd/issue/comment/comment_test.go b/pkg/cmd/issue/comment/comment_test.go index 794dafda4..adee53f7e 100644 --- a/pkg/cmd/issue/comment/comment_test.go +++ b/pkg/cmd/issue/comment/comment_test.go @@ -2,6 +2,7 @@ package comment import ( "bytes" + "errors" "fmt" "net/http" "os" @@ -31,11 +32,13 @@ func TestNewCmdComment(t *testing.T) { stdin string output shared.CommentableOptions wantsErr bool + isTTY bool }{ { name: "no arguments", input: "", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { @@ -46,6 +49,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -56,6 +60,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -66,6 +71,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "test", }, + isTTY: true, wantsErr: false, }, { @@ -77,6 +83,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "this is on standard input", }, + isTTY: true, wantsErr: false, }, { @@ -87,6 +94,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "a body from file", }, + isTTY: true, wantsErr: false, }, { @@ -97,6 +105,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeEditor, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -107,6 +116,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeWeb, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -118,6 +128,7 @@ func TestNewCmdComment(t *testing.T) { Body: "", EditLast: true, }, + isTTY: true, wantsErr: false, }, { @@ -130,42 +141,110 @@ func TestNewCmdComment(t *testing.T) { EditLast: true, CreateIfNone: true, }, + isTTY: true, wantsErr: false, }, + { + name: "delete last flag non-interactive", + input: "1 --delete-last", + isTTY: false, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation non-interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: false, + wantsErr: false, + }, + { + name: "delete last flag interactive", + input: "1 --delete-last", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation with web flag", + input: "1 --delete-last --yes --web", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with editor flag", + input: "1 --delete-last --yes --editor", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with body flag", + input: "1 --delete-last --yes --body", + isTTY: true, + wantsErr: true, + }, + { + name: "delete pre-confirmation without delete last flag", + input: "1 --yes", + isTTY: true, + wantsErr: true, + }, { name: "body and body-file flags", input: "1 --body 'test' --body-file 'test-file.txt'", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and web flags", input: "1 --editor --web", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and body flags", input: "1 --editor --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "web and body flags", input: "1 --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor, web, and body flags", input: "1 --editor --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "create-if-none flag without edit-last", input: "1 --create-if-none", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, } @@ -173,9 +252,10 @@ func TestNewCmdComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, stdin, _, _ := iostreams.Test() - ios.SetStdoutTTY(true) - ios.SetStdinTTY(true) - ios.SetStderrTTY(true) + isTTY := tt.isTTY + ios.SetStdoutTTY(isTTY) + ios.SetStdinTTY(isTTY) + ios.SetStderrTTY(isTTY) if tt.stdin != "" { _, _ = stdin.WriteString(tt.stdin) @@ -211,6 +291,8 @@ func TestNewCmdComment(t *testing.T) { assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) assert.Equal(t, tt.output.InputType, gotOpts.InputType) assert.Equal(t, tt.output.Body, gotOpts.Body) + assert.Equal(t, tt.output.DeleteLast, gotOpts.DeleteLast) + assert.Equal(t, tt.output.DeleteLastConfirmed, gotOpts.DeleteLastConfirmed) }) } } @@ -220,6 +302,7 @@ func Test_commentRun(t *testing.T) { name string input *shared.CommentableOptions emptyComments bool + comments api.Comments httpStubs func(*testing.T, *httpmock.Registry) stdout string stderr string @@ -255,6 +338,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with interactive editor succeeds if there are comments", @@ -331,6 +415,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "creating new comment with non-interactive editor succeeds", @@ -358,6 +443,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with non-interactive editor succeeds if there are comments", @@ -433,6 +519,117 @@ func Test_commentRun(t *testing.T) { }, stdout: "https://github.com/OWNER/REPO/issues/123#issuecomment-456\n", }, + { + name: "deleting last comment non-interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment non-interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmation declined", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + wantsErr: true, + stdout: "deletion not confirmed", + }, + { + name: "deleting last comment interactively and confirmed with long comment body", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "Lorem ipsum dolor sit amet, consectet lo..." { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "Lorem ipsum dolor sit amet, consectet lorem ipsum again"}, + }}, + wantsErr: false, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, } for _, tt := range tests { ios, _, stdout, stderr := iostreams.Test() @@ -458,6 +655,8 @@ func Test_commentRun(t *testing.T) { if tt.emptyComments { comments.Nodes = []api.Comment{} + } else if len(tt.comments.Nodes) > 0 { + comments = tt.comments } tt.input.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) { @@ -472,6 +671,7 @@ func Test_commentRun(t *testing.T) { err := shared.CommentableRun(tt.input) if tt.wantsErr { assert.Error(t, err) + assert.Equal(t, tt.stderr, stderr.String()) return } assert.NoError(t, err) @@ -508,3 +708,15 @@ func mockCommentUpdate(t *testing.T, reg *httpmock.Registry) { }), ) } + +func mockCommentDelete(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation CommentDelete\b`), + httpmock.GraphQLMutation(` + { "data": { "deleteIssueComment": {} } }`, + func(inputs map[string]interface{}) { + assert.Equal(t, "id1", inputs["id"]) + }, + ), + ) +} diff --git a/pkg/cmd/pr/comment/comment.go b/pkg/cmd/pr/comment/comment.go index a2ab4bf9e..2eed7d353 100644 --- a/pkg/cmd/pr/comment/comment.go +++ b/pkg/cmd/pr/comment/comment.go @@ -16,6 +16,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err InteractiveEditSurvey: shared.CommentableInteractiveEditSurvey(f.Config, f.IOStreams), ConfirmSubmitSurvey: shared.CommentableConfirmSubmitSurvey(f.Prompter), ConfirmCreateIfNoneSurvey: shared.CommentableInteractiveCreateIfNoneSurvey(f.Prompter), + ConfirmDeleteLastComment: shared.CommentableConfirmDeleteLastComment(f.Prompter), OpenInBrowser: f.Browser.Browse, } @@ -43,7 +44,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err selector = args[0] } fields := []string{"id", "url"} - if opts.EditLast { + if opts.EditLast || opts.DeleteLast { fields = append(fields, "comments") } finder := shared.NewFinder(f) @@ -75,7 +76,9 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err cmd.Flags().StringVarP(&bodyFile, "body-file", "F", "", "Read body text from `file` (use \"-\" to read from standard input)") cmd.Flags().BoolP("editor", "e", false, "Skip prompts and open the text editor to write the body in") cmd.Flags().BoolP("web", "w", false, "Open the web browser to write the comment") - cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the same author") + cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLast, "delete-last", false, "Delete the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLastConfirmed, "yes", false, "Skip the delete confirmation prompt when --delete-last is provided") cmd.Flags().BoolVar(&opts.CreateIfNone, "create-if-none", false, "Create a new comment if no comments are found. Can be used only with --edit-last") return cmd diff --git a/pkg/cmd/pr/comment/comment_test.go b/pkg/cmd/pr/comment/comment_test.go index 0941f2533..b9d8e153d 100644 --- a/pkg/cmd/pr/comment/comment_test.go +++ b/pkg/cmd/pr/comment/comment_test.go @@ -2,6 +2,7 @@ package comment import ( "bytes" + "errors" "fmt" "net/http" "os" @@ -31,6 +32,7 @@ func TestNewCmdComment(t *testing.T) { stdin string output shared.CommentableOptions wantsErr bool + isTTY bool }{ { name: "no arguments", @@ -40,12 +42,14 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { name: "two arguments", input: "1 2", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { @@ -56,6 +60,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -66,6 +71,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -76,6 +82,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -86,6 +93,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "test", }, + isTTY: true, wantsErr: false, }, { @@ -97,6 +105,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "this is on standard input", }, + isTTY: true, wantsErr: false, }, { @@ -107,6 +116,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "a body from file", }, + isTTY: true, wantsErr: false, }, { @@ -117,6 +127,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeEditor, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -127,6 +138,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeWeb, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -138,6 +150,7 @@ func TestNewCmdComment(t *testing.T) { Body: "", EditLast: true, }, + isTTY: true, wantsErr: false, }, { @@ -150,42 +163,110 @@ func TestNewCmdComment(t *testing.T) { EditLast: true, CreateIfNone: true, }, + isTTY: true, wantsErr: false, }, + { + name: "delete last flag non-interactive", + input: "1 --delete-last", + isTTY: false, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation non-interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: false, + wantsErr: false, + }, + { + name: "delete last flag interactive", + input: "1 --delete-last", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation with web flag", + input: "1 --delete-last --yes --web", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with editor flag", + input: "1 --delete-last --yes --editor", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with body flag", + input: "1 --delete-last --yes --body", + isTTY: true, + wantsErr: true, + }, + { + name: "delete pre-confirmation without delete last flag", + input: "1 --yes", + isTTY: true, + wantsErr: true, + }, { name: "body and body-file flags", input: "1 --body 'test' --body-file 'test-file.txt'", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and web flags", input: "1 --editor --web", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and body flags", input: "1 --editor --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "web and body flags", input: "1 --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor, web, and body flags", input: "1 --editor --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "create-if-none flag without edit-last", input: "1 --create-if-none", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, } @@ -193,9 +274,10 @@ func TestNewCmdComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, stdin, _, _ := iostreams.Test() - ios.SetStdoutTTY(true) - ios.SetStdinTTY(true) - ios.SetStderrTTY(true) + isTTY := tt.isTTY + ios.SetStdoutTTY(isTTY) + ios.SetStdinTTY(isTTY) + ios.SetStderrTTY(isTTY) if tt.stdin != "" { _, _ = stdin.WriteString(tt.stdin) @@ -231,6 +313,8 @@ func TestNewCmdComment(t *testing.T) { assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) assert.Equal(t, tt.output.InputType, gotOpts.InputType) assert.Equal(t, tt.output.Body, gotOpts.Body) + assert.Equal(t, tt.output.DeleteLast, gotOpts.DeleteLast) + assert.Equal(t, tt.output.DeleteLastConfirmed, gotOpts.DeleteLastConfirmed) }) } } @@ -240,6 +324,7 @@ func Test_commentRun(t *testing.T) { name string input *shared.CommentableOptions emptyComments bool + comments api.Comments httpStubs func(*testing.T, *httpmock.Registry) stdout string stderr string @@ -274,6 +359,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with interactive editor succeeds if there are comments", @@ -350,6 +436,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "creating new comment with non-interactive editor succeeds", @@ -377,6 +464,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with non-interactive editor succeeds if there are comments", @@ -451,6 +539,117 @@ func Test_commentRun(t *testing.T) { }, stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n", }, + { + name: "deleting last comment non-interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment non-interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmation declined", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + wantsErr: true, + stdout: "deletion not confirmed", + }, + { + name: "deleting last comment interactively and confirmed with long comment body", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "Lorem ipsum dolor sit amet, consectet lo..." { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "Lorem ipsum dolor sit amet, consectet lorem ipsum again"}, + }}, + wantsErr: false, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, } for _, tt := range tests { ios, _, stdout, stderr := iostreams.Test() @@ -475,6 +674,8 @@ func Test_commentRun(t *testing.T) { }} if tt.emptyComments { comments.Nodes = []api.Comment{} + } else if len(tt.comments.Nodes) > 0 { + comments = tt.comments } tt.input.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) { @@ -489,6 +690,7 @@ func Test_commentRun(t *testing.T) { err := shared.CommentableRun(tt.input) if tt.wantsErr { assert.Error(t, err) + assert.Equal(t, tt.stderr, stderr.String()) return } assert.NoError(t, err) @@ -524,3 +726,15 @@ func mockCommentUpdate(t *testing.T, reg *httpmock.Registry) { }), ) } + +func mockCommentDelete(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation CommentDelete\b`), + httpmock.GraphQLMutation(` + { "data": { "deleteIssueComment": {} } }`, + func(inputs map[string]interface{}) { + assert.Equal(t, "id1", inputs["id"]) + }, + ), + ) +} diff --git a/pkg/cmd/pr/shared/commentable.go b/pkg/cmd/pr/shared/commentable.go index f909c7559..015d84a4b 100644 --- a/pkg/cmd/pr/shared/commentable.go +++ b/pkg/cmd/pr/shared/commentable.go @@ -18,6 +18,7 @@ import ( ) var errNoUserComments = errors.New("no comments found for current user") +var errDeleteNotConfirmed = errors.New("deletion not confirmed") type InputType int @@ -41,11 +42,14 @@ type CommentableOptions struct { InteractiveEditSurvey func(string) (string, error) ConfirmSubmitSurvey func() (bool, error) ConfirmCreateIfNoneSurvey func() (bool, error) + ConfirmDeleteLastComment func(string) (bool, error) OpenInBrowser func(string) error Interactive bool InputType InputType Body string EditLast bool + DeleteLast bool + DeleteLastConfirmed bool CreateIfNone bool Quiet bool Host string @@ -74,6 +78,21 @@ func CommentablePreRun(cmd *cobra.Command, opts *CommentableOptions) error { return cmdutil.FlagErrorf("`--create-if-none` can only be used with `--edit-last`") } + if opts.DeleteLastConfirmed && !opts.DeleteLast { + return cmdutil.FlagErrorf("`--yes` should only be used with `--delete-last`") + } + + if opts.DeleteLast { + if inputFlags > 0 { + return cmdutil.FlagErrorf("should not provide comment body when using `--delete-last`") + } + if opts.IO.CanPrompt() || opts.DeleteLastConfirmed { + opts.Interactive = opts.IO.CanPrompt() + return nil + } + return cmdutil.FlagErrorf("should provide `--yes` to confirm deletion in non-interactive mode") + } + if inputFlags == 0 { if !opts.IO.CanPrompt() { return cmdutil.FlagErrorf("flags required when not running interactively") @@ -92,6 +111,9 @@ func CommentableRun(opts *CommentableOptions) error { return err } opts.Host = repo.RepoHost() + if opts.DeleteLast { + return deleteComment(commentable, opts) + } // Create new comment, bail before complexities of updating the last comment if !opts.EditLast { @@ -236,6 +258,53 @@ func updateComment(commentable Commentable, opts *CommentableOptions) error { return nil } +func deleteComment(commentable Commentable, opts *CommentableOptions) error { + comments := commentable.CurrentUserComments() + if len(comments) == 0 { + return errNoUserComments + } + + lastComment := comments[len(comments)-1] + + cs := opts.IO.ColorScheme() + + if opts.Interactive && !opts.DeleteLastConfirmed { + // This is not an ideal way of truncating a random string that may + // contain emojis or other kind of wide chars. + truncated := lastComment.Body + if len(lastComment.Body) > 40 { + truncated = lastComment.Body[:40] + "..." + } + + fmt.Fprintf(opts.IO.Out, "%s Deleted comments cannot be recovered.\n", cs.WarningIcon()) + ok, err := opts.ConfirmDeleteLastComment(truncated) + if err != nil { + return err + } + if !ok { + return errDeleteNotConfirmed + } + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + apiClient := api.NewClientFromHTTP(httpClient) + params := api.CommentDeleteInput{CommentId: lastComment.Identifier()} + deletionErr := api.CommentDelete(apiClient, opts.Host, params) + if deletionErr != nil { + return deletionErr + } + + if !opts.Quiet { + fmt.Fprintln(opts.IO.ErrOut, "Comment deleted") + } + + return nil +} + func CommentableConfirmSubmitSurvey(p Prompt) func() (bool, error) { return func() (bool, error) { return p.Confirm("Submit?", true) @@ -271,6 +340,12 @@ func CommentableEditSurvey(cf func() (gh.Config, error), io *iostreams.IOStreams } } +func CommentableConfirmDeleteLastComment(p Prompt) func(string) (bool, error) { + return func(body string) (bool, error) { + return p.Confirm(fmt.Sprintf("Delete the comment: %q?", body), true) + } +} + func waitForEnter(r io.Reader) error { scanner := bufio.NewScanner(r) scanner.Scan() From 3bcf9758ad24a3c8caf000c91b42a332c132877c Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 17 Apr 2025 22:13:31 +0200 Subject: [PATCH 43/51] Feature detect v1 projects on pr view --- pkg/cmd/pr/shared/finder.go | 29 +++++++++++-- pkg/cmd/pr/view/view.go | 5 +++ pkg/cmd/pr/view/view_test.go | 81 ++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 3 deletions(-) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index 6d36ef816..b509f946c 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -16,6 +16,7 @@ import ( ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" fd "github.com/cli/cli/v2/internal/featuredetection" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmdutil" o "github.com/cli/cli/v2/pkg/option" @@ -79,6 +80,10 @@ func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interfac return finder } +func ResetRunCommandFinder() { + runCommandFinder = nil +} + type FindOptions struct { // Selector can be a number with optional `#` prefix, a branch name with optional `:` prefix, or // a PR URL. @@ -89,6 +94,8 @@ type FindOptions struct { BaseBranch string // States lists the possible PR states to scope the PR-for-branch lookup to. States []string + + Detector fd.Detector } func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { @@ -193,9 +200,11 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err fields.AddValues([]string{"id", "number"}) // for additional preload queries below if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") { - cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) - detector := fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) - prFeatures, err := detector.PullRequestFeatures() + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) + } + prFeatures, err := opts.Detector.PullRequestFeatures() if err != nil { return nil, nil, err } @@ -211,6 +220,20 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err fields.Remove("projectItems") } + // TODO projectsV1Deprecation + // Remove this block + // When removing this, remember to remove `projectCards` from the list of default fields in pr/view.go + if fields.Contains("projectCards") { + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) + } + + if opts.Detector.ProjectsV1() == gh.ProjectsV1Unsupported { + fields.Remove("projectCards") + } + } + var pr *api.PullRequest if f.prNumber > 0 { if numberFieldOnly { diff --git a/pkg/cmd/pr/view/view.go b/pkg/cmd/pr/view/view.go index 997f74d87..8a39d1134 100644 --- a/pkg/cmd/pr/view/view.go +++ b/pkg/cmd/pr/view/view.go @@ -10,6 +10,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/pr/shared" @@ -22,6 +23,9 @@ import ( type ViewOptions struct { IO *iostreams.IOStreams Browser browser.Browser + // TODO projectsV1Deprecation + // Remove this detector since it is only used for test validation. + Detector fd.Detector Finder shared.PRFinder Exporter cmdutil.Exporter @@ -89,6 +93,7 @@ func viewRun(opts *ViewOptions) error { findOptions := shared.FindOptions{ Selector: opts.SelectorArg, Fields: defaultFields, + Detector: opts.Detector, } if opts.BrowserMode { findOptions.Fields = []string{"url"} diff --git a/pkg/cmd/pr/view/view_test.go b/pkg/cmd/pr/view/view_test.go index e7f572c76..3a2a87e5c 100644 --- a/pkg/cmd/pr/view/view_test.go +++ b/pkg/cmd/pr/view/view_test.go @@ -12,6 +12,7 @@ import ( "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/run" "github.com/cli/cli/v2/pkg/cmd/pr/shared" @@ -175,6 +176,9 @@ func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*t factory := &cmdutil.Factory{ IOStreams: ios, Browser: browser, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: rt}, nil + }, } cmd := NewCmdView(factory, nil) @@ -398,6 +402,8 @@ func TestPRView_Preview_nontty(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + t.Cleanup(shared.ResetRunCommandFinder) + http := &httpmock.Registry{} defer http.Verify(t) @@ -602,6 +608,8 @@ func TestPRView_Preview(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + t.Cleanup(shared.ResetRunCommandFinder) + http := &httpmock.Registry{} defer http.Verify(t) @@ -846,6 +854,8 @@ func TestPRView_nontty_Comments(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { + t.Cleanup(shared.ResetRunCommandFinder) + http := &httpmock.Registry{} defer http.Verify(t) @@ -869,3 +879,74 @@ func TestPRView_nontty_Comments(t *testing.T) { }) } } + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + t.Run("when projects v1 is supported, is included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + httpmock.GraphQL(`projectCards`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + f := &cmdutil.Factory{ + IOStreams: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = viewRun(&ViewOptions{ + IO: ios, + Finder: shared.NewFinder(f), + Detector: &fd.EnabledDetectorMock{}, + + SelectorArg: "https://github.com/cli/cli/pull/123", + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, is not included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Exclude( + t, + httpmock.GraphQL(`projectCards`), + ) + + f := &cmdutil.Factory{ + IOStreams: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = viewRun(&ViewOptions{ + IO: ios, + Finder: shared.NewFinder(f), + Detector: &fd.DisabledDetectorMock{}, + + SelectorArg: "https://github.com/cli/cli/pull/123", + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) +} From 284880c21edec4a913ce9d33705531fafeb14ee8 Mon Sep 17 00:00:00 2001 From: "Babak K. Shandiz" Date: Thu, 1 May 2025 20:22:43 +0100 Subject: [PATCH 44/51] Fix `StatusJSONResponse` usage (#10810) * Fix `StatusJSONResponse` usage Signed-off-by: Babak K. Shandiz * Replace `assert` with `require` Signed-off-by: Babak K. Shandiz * Improve assertion against errors Signed-off-by: Babak K. Shandiz * Add `JSONErrorResponse` helper func Signed-off-by: Babak K. Shandiz * Use `httpmock.JSONErrorResponse` Signed-off-by: Babak K. Shandiz * Replace `StatusJSONResponse` to `JSONErrorResponse` for better readibility Signed-off-by: Babak K. Shandiz * Fix improper use of `StatsJSONResponse` Signed-off-by: Babak K. Shandiz --------- Signed-off-by: Babak K. Shandiz --- pkg/cmd/gist/delete/delete_test.go | 50 ++++++++++++----------- pkg/cmd/gpg-key/delete/delete_test.go | 2 +- pkg/cmd/repo/autolink/delete/http_test.go | 22 +++++----- pkg/cmd/run/watch/watch_test.go | 2 +- pkg/httpmock/stub.go | 11 +++++ 5 files changed, 52 insertions(+), 35 deletions(-) diff --git a/pkg/cmd/gist/delete/delete_test.go b/pkg/cmd/gist/delete/delete_test.go index 24ca2bb33..2c4df8d8d 100644 --- a/pkg/cmd/gist/delete/delete_test.go +++ b/pkg/cmd/gist/delete/delete_test.go @@ -18,6 +18,7 @@ import ( ghAPI "github.com/cli/go-gh/v2/pkg/api" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCmdDelete(t *testing.T) { @@ -327,11 +328,12 @@ func Test_deleteRun(t *testing.T) { func Test_gistDelete(t *testing.T) { tests := []struct { - name string - httpStubs func(*httpmock.Registry) - hostname string - gistID string - wantErr error + name string + httpStubs func(*httpmock.Registry) + hostname string + gistID string + wantErr error + wantErrString string }{ { name: "successful delete", @@ -343,36 +345,34 @@ func Test_gistDelete(t *testing.T) { }, hostname: "github.com", gistID: "1234", - wantErr: nil, }, { - name: "when an gist is not found, it returns a NotFoundError", + name: "when a gist is not found, it returns a NotFoundError", httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.REST("DELETE", "gists/1234"), httpmock.StatusStringResponse(404, "{}"), ) }, - hostname: "github.com", - gistID: "1234", - wantErr: shared.NotFoundErr, + hostname: "github.com", + gistID: "1234", + wantErr: shared.NotFoundErr, // To make sure we return the pre-defined error instance. + wantErrString: "not found", }, { name: "when there is a non-404 error deleting the gist, that error is returned", httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.REST("DELETE", "gists/1234"), - httpmock.StatusJSONResponse(500, `{"message": "arbitrary error"}`), + httpmock.JSONErrorResponse(500, ghAPI.HTTPError{ + StatusCode: 500, + Message: "arbitrary error", + }), ) }, - hostname: "github.com", - gistID: "1234", - wantErr: api.HTTPError{ - HTTPError: &ghAPI.HTTPError{ - StatusCode: 500, - Message: "arbitrary error", - }, - }, + hostname: "github.com", + gistID: "1234", + wantErrString: "HTTP 500: arbitrary error (https://api.github.com/gists/1234)", }, } @@ -383,12 +383,16 @@ func Test_gistDelete(t *testing.T) { client := api.NewClientFromHTTP(&http.Client{Transport: reg}) err := deleteGist(client, tt.hostname, tt.gistID) - if tt.wantErr != nil { - assert.ErrorAs(t, err, &tt.wantErr) + if tt.wantErrString == "" && tt.wantErr == nil { + require.NoError(t, err) } else { - assert.NoError(t, err) + if tt.wantErrString != "" { + require.EqualError(t, err, tt.wantErrString) + } + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + } } - }) } } diff --git a/pkg/cmd/gpg-key/delete/delete_test.go b/pkg/cmd/gpg-key/delete/delete_test.go index 115e72db2..dc730b100 100644 --- a/pkg/cmd/gpg-key/delete/delete_test.go +++ b/pkg/cmd/gpg-key/delete/delete_test.go @@ -177,7 +177,7 @@ func Test_deleteRun(t *testing.T) { opts: DeleteOptions{KeyID: "ABC123", Confirmed: true}, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "user/gpg_keys"), httpmock.StatusStringResponse(200, keysResp)) - reg.Register(httpmock.REST("DELETE", "user/gpg_keys/123"), httpmock.StatusJSONResponse(404, api.HTTPError{ + reg.Register(httpmock.REST("DELETE", "user/gpg_keys/123"), httpmock.JSONErrorResponse(404, api.HTTPError{ StatusCode: 404, Message: "GPG key 123 not found", })) diff --git a/pkg/cmd/repo/autolink/delete/http_test.go b/pkg/cmd/repo/autolink/delete/http_test.go index a2676178d..a0aec5e13 100644 --- a/pkg/cmd/repo/autolink/delete/http_test.go +++ b/pkg/cmd/repo/autolink/delete/http_test.go @@ -7,6 +7,7 @@ import ( "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" + "github.com/cli/go-gh/v2/pkg/api" "github.com/stretchr/testify/require" ) @@ -14,10 +15,10 @@ func TestAutolinkDeleter_Delete(t *testing.T) { repo := ghrepo.New("OWNER", "REPO") tests := []struct { - name string - id string - stubStatus int - stubRespJSON string + name string + id string + stubStatus int + stubResp any expectErr bool expectedErrMsg string @@ -31,17 +32,18 @@ func TestAutolinkDeleter_Delete(t *testing.T) { name: "404 repo or autolink not found", id: "123", stubStatus: http.StatusNotFound, - stubRespJSON: `{}`, // API response not used in output expectErr: true, expectedErrMsg: "error deleting autolink: HTTP 404: Perhaps you are missing admin rights to the repository? (https://api.github.com/repos/OWNER/REPO/autolinks/123)", }, { - name: "500 unexpected error", - id: "123", - stubRespJSON: `{"messsage": "arbitrary error"}`, + name: "500 unexpected error", + id: "123", + stubResp: api.HTTPError{ + Message: "arbitrary error", + }, stubStatus: http.StatusInternalServerError, expectErr: true, - expectedErrMsg: "HTTP 500 (https://api.github.com/repos/OWNER/REPO/autolinks/123)", + expectedErrMsg: "HTTP 500: arbitrary error (https://api.github.com/repos/OWNER/REPO/autolinks/123)", }, } @@ -53,7 +55,7 @@ func TestAutolinkDeleter_Delete(t *testing.T) { http.MethodDelete, fmt.Sprintf("repos/%s/%s/autolinks/%s", repo.RepoOwner(), repo.RepoName(), tt.id), ), - httpmock.StatusJSONResponse(tt.stubStatus, tt.stubRespJSON), + httpmock.StatusJSONResponse(tt.stubStatus, tt.stubResp), ) defer reg.Verify(t) diff --git a/pkg/cmd/run/watch/watch_test.go b/pkg/cmd/run/watch/watch_test.go index d42e8d3d8..49e56217b 100644 --- a/pkg/cmd/run/watch/watch_test.go +++ b/pkg/cmd/run/watch/watch_test.go @@ -316,7 +316,7 @@ func TestWatchRun(t *testing.T) { ) reg.Register( httpmock.REST("GET", "repos/OWNER/REPO/actions/runs/1234"), - httpmock.StatusJSONResponse(404, api.HTTPError{ + httpmock.JSONErrorResponse(404, api.HTTPError{ StatusCode: 404, Message: "run 1234 not found", }), diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index 745c12417..3b03ae718 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -9,6 +9,8 @@ import ( "os" "regexp" "strings" + + "github.com/cli/go-gh/v2/pkg/api" ) type Matcher func(req *http.Request) bool @@ -161,6 +163,9 @@ func JSONResponse(body interface{}) Responder { } } +// StatusJSONResponse turns the given argument into a JSON response. +// +// The argument is not meant to be a JSON string, unless it's intentional. func StatusJSONResponse(status int, body interface{}) Responder { return func(req *http.Request) (*http.Response, error) { b, _ := json.Marshal(body) @@ -171,6 +176,12 @@ func StatusJSONResponse(status int, body interface{}) Responder { } } +// JSONErrorResponse is a type-safe helper to avoid confusion around the +// provided argument. +func JSONErrorResponse(status int, err api.HTTPError) Responder { + return StatusJSONResponse(status, err) +} + func FileResponse(filename string) Responder { return func(req *http.Request) (*http.Response, error) { f, err := os.Open(filename) From 64370ce73e6774cd5c7ec912c4cafbd513f2af73 Mon Sep 17 00:00:00 2001 From: William Martin Date: Fri, 2 May 2025 14:41:24 +0200 Subject: [PATCH 45/51] Cleanup run command stubbed finders in tests --- pkg/cmd/pr/checkout/checkout_test.go | 26 +++++------ pkg/cmd/pr/close/close_test.go | 14 +++--- pkg/cmd/pr/merge/merge_test.go | 66 ++++++++++++++-------------- pkg/cmd/pr/ready/ready_test.go | 10 ++--- pkg/cmd/pr/reopen/reopen_test.go | 8 ++-- pkg/cmd/pr/review/review_test.go | 8 ++-- pkg/cmd/pr/shared/finder.go | 30 ++++++++----- pkg/cmd/pr/view/view_test.go | 22 ++++------ 8 files changed, 93 insertions(+), 91 deletions(-) diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index 40917fd76..496139423 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -518,7 +518,7 @@ func TestPRCheckout_sameRepo(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -539,7 +539,7 @@ func TestPRCheckout_existingBranch(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -570,7 +570,7 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -590,7 +590,7 @@ func TestPRCheckout_differentRepo(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -613,7 +613,7 @@ func TestPRCheckout_differentRepoForce(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -636,7 +636,7 @@ func TestPRCheckout_differentRepo_existingBranch(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -655,7 +655,7 @@ func TestPRCheckout_detachedHead(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -674,7 +674,7 @@ func TestPRCheckout_differentRepo_currentBranch(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -693,7 +693,7 @@ func TestPRCheckout_differentRepo_invalidBranchName(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:-foo") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -711,7 +711,7 @@ func TestPRCheckout_maintainerCanModify(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") pr.MaintainerCanModify = true - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -732,7 +732,7 @@ func TestPRCheckout_recurseSubmodules(t *testing.T) { http := &httpmock.Registry{} baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -753,7 +753,7 @@ func TestPRCheckout_force(t *testing.T) { http := &httpmock.Registry{} baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -774,7 +774,7 @@ func TestPRCheckout_detach(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) diff --git a/pkg/cmd/pr/close/close_test.go b/pkg/cmd/pr/close/close_test.go index 959af0e04..57ee0f0e6 100644 --- a/pkg/cmd/pr/close/close_test.go +++ b/pkg/cmd/pr/close/close_test.go @@ -110,7 +110,7 @@ func TestPrClose(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -133,7 +133,7 @@ func TestPrClose_alreadyClosed(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") pr.State = "CLOSED" pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) output, err := runCommand(http, true, "96") assert.NoError(t, err) @@ -147,7 +147,7 @@ func TestPrClose_deleteBranch_sameRepo(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:blueberries") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -181,7 +181,7 @@ func TestPrClose_deleteBranch_crossRepo(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:blueberries") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -213,7 +213,7 @@ func TestPrClose_deleteBranch_sameBranch(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO:main", "OWNER/REPO:trunk") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -248,7 +248,7 @@ func TestPrClose_deleteBranch_notInGitRepo(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO:main", "OWNER/REPO:trunk") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -282,7 +282,7 @@ func TestPrClose_withComment(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation CommentCreate\b`), diff --git a/pkg/cmd/pr/merge/merge_test.go b/pkg/cmd/pr/merge/merge_test.go index f1c2e37fe..4ca8c5d06 100644 --- a/pkg/cmd/pr/merge/merge_test.go +++ b/pkg/cmd/pr/merge/merge_test.go @@ -307,7 +307,7 @@ func TestPrMerge(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -348,7 +348,7 @@ func TestPrMerge_blocked(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -379,7 +379,7 @@ func TestPrMerge_dirty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -413,7 +413,7 @@ func TestPrMerge_nontty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -451,7 +451,7 @@ func TestPrMerge_editMessage_nontty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -490,7 +490,7 @@ func TestPrMerge_withRepoFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -529,7 +529,7 @@ func TestPrMerge_withMatchCommitHeadFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -570,7 +570,7 @@ func TestPrMerge_withAuthorFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -612,7 +612,7 @@ func TestPrMerge_deleteBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -663,7 +663,7 @@ func TestPrMerge_deleteBranch_mergeQueue(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -686,7 +686,7 @@ func TestPrMerge_deleteBranch_nonDefault(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -737,7 +737,7 @@ func TestPrMerge_deleteBranch_onlyLocally(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -785,7 +785,7 @@ func TestPrMerge_deleteBranch_checkoutNewBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -836,7 +836,7 @@ func TestPrMerge_deleteNonCurrentBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "blueberries", &api.PullRequest{ ID: "PR_10", @@ -893,7 +893,7 @@ func Test_nonDivergingPullRequest(t *testing.T) { } stubCommit(pr, "COMMITSHA1") - shared.RunCommandFinder("", pr, baseRepo("OWNER", "REPO", "main")) + shared.StubFinderForRunCommandStyleTests(t, "", pr, baseRepo("OWNER", "REPO", "main")) http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), @@ -933,7 +933,7 @@ func Test_divergingPullRequestWarning(t *testing.T) { } stubCommit(pr, "COMMITSHA1") - shared.RunCommandFinder("", pr, baseRepo("OWNER", "REPO", "main")) + shared.StubFinderForRunCommandStyleTests(t, "", pr, baseRepo("OWNER", "REPO", "main")) http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), @@ -964,7 +964,7 @@ func Test_pullRequestWithoutCommits(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -1003,7 +1003,7 @@ func TestPrMerge_rebase(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "2", &api.PullRequest{ ID: "THE-ID", @@ -1044,7 +1044,7 @@ func TestPrMerge_squash(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "3", &api.PullRequest{ ID: "THE-ID", @@ -1084,7 +1084,7 @@ func TestPrMerge_alreadyMerged(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1129,7 +1129,7 @@ func TestPrMerge_alreadyMerged_withMergeStrategy(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1159,7 +1159,7 @@ func TestPrMerge_alreadyMerged_withMergeStrategy_TTY(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1200,7 +1200,7 @@ func TestPrMerge_alreadyMerged_withMergeStrategy_crossRepo(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1239,7 +1239,7 @@ func TestPRMergeTTY(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "THE-ID", @@ -1305,7 +1305,7 @@ func TestPRMergeTTY_withDeleteBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "THE-ID", @@ -1468,7 +1468,7 @@ func TestPRMergeEmptyStrategyNonTTY(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1495,7 +1495,7 @@ func TestPRTTY_cancelled(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123, Title: "title", MergeStateStatus: "CLEAN"}, ghrepo.New("OWNER", "REPO"), @@ -1679,7 +1679,7 @@ func TestPrInMergeQueue(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1710,7 +1710,7 @@ func TestPrAddToMergeQueueWithMergeMethod(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1748,7 +1748,7 @@ func TestPrAddToMergeQueueClean(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1788,7 +1788,7 @@ func TestPrAddToMergeQueueBlocked(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1828,7 +1828,7 @@ func TestPrAddToMergeQueueAdmin(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1897,7 +1897,7 @@ func TestPrAddToMergeQueueAdminWithMergeStrategy(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", diff --git a/pkg/cmd/pr/ready/ready_test.go b/pkg/cmd/pr/ready/ready_test.go index 9046ab3ac..5a6053a17 100644 --- a/pkg/cmd/pr/ready/ready_test.go +++ b/pkg/cmd/pr/ready/ready_test.go @@ -124,7 +124,7 @@ func TestPRReady(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -149,7 +149,7 @@ func TestPRReady_alreadyReady(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -166,7 +166,7 @@ func TestPRReadyUndo(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -191,7 +191,7 @@ func TestPRReadyUndo_alreadyDraft(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -208,7 +208,7 @@ func TestPRReady_closed(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "CLOSED", diff --git a/pkg/cmd/pr/reopen/reopen_test.go b/pkg/cmd/pr/reopen/reopen_test.go index 856e19172..9fb3702c0 100644 --- a/pkg/cmd/pr/reopen/reopen_test.go +++ b/pkg/cmd/pr/reopen/reopen_test.go @@ -53,7 +53,7 @@ func TestPRReopen(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "CLOSED", @@ -78,7 +78,7 @@ func TestPRReopen_alreadyOpen(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -95,7 +95,7 @@ func TestPRReopen_alreadyMerged(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "MERGED", @@ -112,7 +112,7 @@ func TestPRReopen_withComment(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "CLOSED", diff --git a/pkg/cmd/pr/review/review_test.go b/pkg/cmd/pr/review/review_test.go index f9e00c3b8..684617ca9 100644 --- a/pkg/cmd/pr/review/review_test.go +++ b/pkg/cmd/pr/review/review_test.go @@ -235,7 +235,7 @@ func TestPRReview(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID"}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID"}, ghrepo.New("OWNER", "REPO")) http.Register( httpmock.GraphQL(`mutation PullRequestReviewAdd\b`), @@ -261,7 +261,7 @@ func TestPRReview_interactive(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) http.Register( httpmock.GraphQL(`mutation PullRequestReviewAdd\b`), @@ -293,7 +293,7 @@ func TestPRReview_interactive_no_body(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) pm := &prompter.PrompterMock{ SelectFunc: func(_, _ string, _ []string) (int, error) { return 2, nil }, @@ -308,7 +308,7 @@ func TestPRReview_interactive_blank_approve(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) http.Register( httpmock.GraphQL(`mutation PullRequestReviewAdd\b`), diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index b509f946c..04e8baf2c 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -10,6 +10,7 @@ import ( "sort" "strconv" "strings" + "testing" "time" "github.com/cli/cli/v2/api" @@ -55,9 +56,9 @@ type finder struct { } func NewFinder(factory *cmdutil.Factory) PRFinder { - if runCommandFinder != nil { - f := runCommandFinder - runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")} + if finderForRunCommandStyleTests != nil { + f := finderForRunCommandStyleTests + finderForRunCommandStyleTests = &mockFinder{err: errors.New("you must use StubFinderForRunCommandStyleTests to stub PR lookups")} return f } @@ -71,17 +72,24 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { } } -var runCommandFinder PRFinder +var finderForRunCommandStyleTests PRFinder -// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. -func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { +// StubFinderForRunCommandStyleTests is the NewMockFinder substitute to be used ONLY in runCommand-style tests. +func StubFinderForRunCommandStyleTests(t *testing.T, selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { + // Create a new mock finder and override the "runCommandFinder" variable so that calls to + // NewFinder() will return this mock. This is a bad pattern, and a result of old style runCommand + // tests that would ideally be replaced. The reason we need to do this is that the runCommand style tests + // construct the cobra command via NewCmd* functions, and then Execute them directly, providing no opportunity + // to inject a test double unless it's on the factory, which finder never is, because only PR commands need it. finder := NewMockFinder(selector, pr, repo) - runCommandFinder = finder - return finder -} + finderForRunCommandStyleTests = finder -func ResetRunCommandFinder() { - runCommandFinder = nil + // Ensure that at the end of the test, we reset the "runCommandFinder" variable so that tests are isolated, + // at least if they are run sequentially. + t.Cleanup(func() { + finderForRunCommandStyleTests = nil + }) + return finder } type FindOptions struct { diff --git a/pkg/cmd/pr/view/view_test.go b/pkg/cmd/pr/view/view_test.go index 3a2a87e5c..ec1691305 100644 --- a/pkg/cmd/pr/view/view_test.go +++ b/pkg/cmd/pr/view/view_test.go @@ -402,14 +402,12 @@ func TestPRView_Preview_nontty(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - t.Cleanup(shared.ResetRunCommandFinder) - http := &httpmock.Registry{} defer http.Verify(t) pr, err := prFromFixtures(tc.fixtures) require.NoError(t, err) - shared.RunCommandFinder("12", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "12", pr, ghrepo.New("OWNER", "REPO")) output, err := runCommand(http, tc.branch, false, tc.args) if err != nil { @@ -608,14 +606,12 @@ func TestPRView_Preview(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - t.Cleanup(shared.ResetRunCommandFinder) - http := &httpmock.Registry{} defer http.Verify(t) pr, err := prFromFixtures(tc.fixtures) require.NoError(t, err) - shared.RunCommandFinder("12", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "12", pr, ghrepo.New("OWNER", "REPO")) output, err := runCommand(http, tc.branch, true, tc.args) if err != nil { @@ -638,7 +634,7 @@ func TestPRView_web_currentBranch(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{URL: "https://github.com/OWNER/REPO/pull/10"}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{URL: "https://github.com/OWNER/REPO/pull/10"}, ghrepo.New("OWNER", "REPO")) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -657,7 +653,7 @@ func TestPRView_web_noResultsForBranch(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", nil, nil) + shared.StubFinderForRunCommandStyleTests(t, "", nil, nil) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -749,9 +745,9 @@ func TestPRView_tty_Comments(t *testing.T) { if len(tt.fixtures) > 0 { pr, err := prFromFixtures(tt.fixtures) require.NoError(t, err) - shared.RunCommandFinder("123", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, ghrepo.New("OWNER", "REPO")) } else { - shared.RunCommandFinder("123", nil, nil) + shared.StubFinderForRunCommandStyleTests(t, "123", nil, nil) } output, err := runCommand(http, tt.branch, true, tt.cli) @@ -854,17 +850,15 @@ func TestPRView_nontty_Comments(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { - t.Cleanup(shared.ResetRunCommandFinder) - http := &httpmock.Registry{} defer http.Verify(t) if len(tt.fixtures) > 0 { pr, err := prFromFixtures(tt.fixtures) require.NoError(t, err) - shared.RunCommandFinder("123", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, ghrepo.New("OWNER", "REPO")) } else { - shared.RunCommandFinder("123", nil, nil) + shared.StubFinderForRunCommandStyleTests(t, "123", nil, nil) } output, err := runCommand(http, tt.branch, false, tt.cli) From e995a873cb7ad4be48c1ab1fbf51a85b7d0280c4 Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 1 May 2025 15:58:48 +0200 Subject: [PATCH 46/51] Feature detect v1 projects on non-interactive pr create --- api/queries_repo.go | 41 ++++++----- api/queries_repo_test.go | 8 +- pkg/cmd/pr/create/create.go | 30 ++++++-- pkg/cmd/pr/create/create_test.go | 121 ++++++++++++++++++++++++++++++- pkg/cmd/pr/shared/editable.go | 6 +- pkg/cmd/pr/shared/params.go | 4 +- 6 files changed, 174 insertions(+), 36 deletions(-) diff --git a/api/queries_repo.go b/api/queries_repo.go index 27e21eb32..93a32d80c 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -738,34 +738,37 @@ func (m *RepoMetadataResult) LabelsToIDs(names []string) ([]string, error) { return ids, nil } -// ProjectsToIDs returns two arrays: +// ProjectsTitlesToIDs returns two arrays: // - the first contains IDs of projects V1 // - the second contains IDs of projects V2 // - if neither project V1 or project V2 can be found with a given name, then an error is returned -func (m *RepoMetadataResult) ProjectsToIDs(names []string) ([]string, []string, error) { +func (m *RepoMetadataResult) ProjectsTitlesToIDs(titles []string) ([]string, []string, error) { var ids []string var idsV2 []string - for _, projectName := range names { - id, found := m.projectNameToID(projectName) + for _, title := range titles { + id, found := m.v1ProjectNameToID(title) if found { ids = append(ids, id) continue } - idV2, found := m.projectV2TitleToID(projectName) + idV2, found := m.v2ProjectTitleToID(title) if found { idsV2 = append(idsV2, idV2) continue } - return nil, nil, fmt.Errorf("'%s' not found", projectName) + return nil, nil, fmt.Errorf("'%s' not found", title) } return ids, idsV2, nil } -func (m *RepoMetadataResult) projectNameToID(projectName string) (string, bool) { +// We use the word "titles" when referring to v1 and v2 projects. +// In reality, v1 projects really have "names", so there is a bit of a +// mismatch we just need to gloss over. +func (m *RepoMetadataResult) v1ProjectNameToID(name string) (string, bool) { for _, p := range m.Projects { - if strings.EqualFold(projectName, p.Name) { + if strings.EqualFold(name, p.Name) { return p.ID, true } } @@ -773,9 +776,9 @@ func (m *RepoMetadataResult) projectNameToID(projectName string) (string, bool) return "", false } -func (m *RepoMetadataResult) projectV2TitleToID(projectTitle string) (string, bool) { +func (m *RepoMetadataResult) v2ProjectTitleToID(title string) (string, bool) { for _, p := range m.ProjectsV2 { - if strings.EqualFold(projectTitle, p.Title) { + if strings.EqualFold(title, p.Title) { return p.ID, true } } @@ -783,8 +786,8 @@ func (m *RepoMetadataResult) projectV2TitleToID(projectTitle string) (string, bo return "", false } -func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []string, projectsV1Support gh.ProjectsV1Support) ([]string, error) { - paths := make([]string, 0, len(projectNames)) +func ProjectTitlesToPaths(client *Client, repo ghrepo.Interface, titles []string, projectsV1Support gh.ProjectsV1Support) ([]string, error) { + paths := make([]string, 0, len(titles)) matchedPaths := map[string]struct{}{} // TODO: ProjectsV1Cleanup @@ -796,9 +799,9 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s return nil, err } - for _, projectName := range projectNames { + for _, title := range titles { for _, p := range v1Projects { - if strings.EqualFold(projectName, p.Name) { + if strings.EqualFold(title, p.Name) { pathParts := strings.Split(p.ResourcePath, "/") var path string if pathParts[1] == "orgs" || pathParts[1] == "users" { @@ -807,7 +810,7 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s path = fmt.Sprintf("%s/%s/%s", pathParts[1], pathParts[2], pathParts[4]) } paths = append(paths, path) - matchedPaths[projectName] = struct{}{} + matchedPaths[title] = struct{}{} break } } @@ -820,15 +823,15 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s return nil, err } - for _, projectName := range projectNames { + for _, title := range titles { // If we already found a v1 project with this name, skip it - if _, ok := matchedPaths[projectName]; ok { + if _, ok := matchedPaths[title]; ok { continue } found := false for _, p := range v2Projects { - if strings.EqualFold(projectName, p.Title) { + if strings.EqualFold(title, p.Title) { pathParts := strings.Split(p.ResourcePath, "/") var path string if pathParts[1] == "orgs" || pathParts[1] == "users" { @@ -843,7 +846,7 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s } if !found { - return nil, fmt.Errorf("'%s' not found", projectName) + return nil, fmt.Errorf("'%s' not found", title) } } diff --git a/api/queries_repo_test.go b/api/queries_repo_test.go index 72ed35776..01fc7a4c7 100644 --- a/api/queries_repo_test.go +++ b/api/queries_repo_test.go @@ -187,7 +187,7 @@ func Test_RepoMetadata(t *testing.T) { expectedProjectIDs := []string{"TRIAGEID", "ROADMAPID"} expectedProjectV2IDs := []string{"TRIAGEV2ID", "ROADMAPV2ID", "MONALISAV2ID"} - projectIDs, projectV2IDs, err := result.ProjectsToIDs([]string{"triage", "roadmap", "triagev2", "roadmapv2", "monalisav2"}) + projectIDs, projectV2IDs, err := result.ProjectsTitlesToIDs([]string{"triage", "roadmap", "triagev2", "roadmapv2", "monalisav2"}) if err != nil { t.Errorf("error resolving projects: %v", err) } @@ -273,7 +273,7 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - projectPaths, err := ProjectNamesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Supported) + projectPaths, err := ProjectTitlesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Supported) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -331,7 +331,7 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - projectPaths, err := ProjectNamesToPaths(client, repo, []string{"TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Unsupported) + projectPaths, err := ProjectTitlesToPaths(client, repo, []string{"TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Unsupported) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -374,7 +374,7 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - _, err := ProjectNamesToPaths(client, repo, []string{"TriageV2"}, gh.ProjectsV1Unsupported) + _, err := ProjectTitlesToPaths(client, repo, []string{"TriageV2"}, gh.ProjectsV1Unsupported) require.Equal(t, err, fmt.Errorf("'TriageV2' not found")) }) } diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 7f960bce4..483b8246b 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -18,6 +18,7 @@ import ( ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" @@ -31,6 +32,7 @@ import ( type CreateOptions struct { // This struct stores user input and factory functions + Detector fd.Detector HttpClient func() (*http.Client, error) GitClient *git.Client Config func() (gh.Config, error) @@ -363,6 +365,20 @@ func createRun(opts *CreateOptions) error { return err } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + // TODO projectsV1Deprecation + // Remove this section as we should no longer need to detect + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, ctx.PRRefs.BaseRepo().RepoHost()) + } + + projectsV1Support := opts.Detector.ProjectsV1() + client := ctx.Client state, err := NewIssueState(*ctx, *opts) @@ -384,7 +400,7 @@ func createRun(opts *CreateOptions) error { if err != nil { return err } - openURL, err = generateCompareURL(*ctx, *state) + openURL, err = generateCompareURL(*ctx, *state, gh.ProjectsV1Supported) if err != nil { return err } @@ -441,7 +457,7 @@ func createRun(opts *CreateOptions) error { return err } // TODO wm: revisit project support - return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) + return submitPR(*opts, *ctx, *state, projectsV1Support) } if opts.RecoverFile != "" { @@ -518,7 +534,7 @@ func createRun(opts *CreateOptions) error { } } - openURL, err = generateCompareURL(*ctx, *state) + openURL, err = generateCompareURL(*ctx, *state, gh.ProjectsV1Supported) if err != nil { return err } @@ -568,12 +584,12 @@ func createRun(opts *CreateOptions) error { if action == shared.SubmitDraftAction { state.Draft = true // TODO wm: revisit project support - return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) + return submitPR(*opts, *ctx, *state, projectsV1Support) } if action == shared.SubmitAction { // TODO wm: revisit project support - return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) + return submitPR(*opts, *ctx, *state, projectsV1Support) } err = errors.New("expected to cancel, preview, or submit") @@ -1216,13 +1232,13 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return pushBranch() } -func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (string, error) { +func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState, projectsV1Support gh.ProjectsV1Support) (string, error) { u := ghrepo.GenerateRepoURL( ctx.PRRefs.BaseRepo(), "compare/%s...%s?expand=1", url.PathEscape(ctx.PRRefs.BaseRef()), url.PathEscape(ctx.PRRefs.QualifiedHeadRef())) // TODO wm: revisit project support - url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state, gh.ProjectsV1Supported) + url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state, projectsV1Support) if err != nil { return "", err } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 2a88b5eee..f3b99bc89 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -15,6 +15,7 @@ import ( "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/config" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" @@ -1618,6 +1619,7 @@ func Test_createRun(t *testing.T) { } opts := CreateOptions{} + opts.Detector = &fd.EnabledDetectorMock{} opts.Prompter = pm ios, _, stdout, stderr := iostreams.Test() @@ -1941,7 +1943,8 @@ func Test_generateCompareURL(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateCompareURL(tt.ctx, tt.state) + // TODO wm: projects v1 support? + got, err := generateCompareURL(tt.ctx, tt.state, gh.ProjectsV1Supported) if (err != nil) != tt.wantErr { t.Errorf("generateCompareURL() error = %v, wantErr %v", err, tt.wantErr) return @@ -2009,3 +2012,119 @@ func mockRetrieveProjects(_ *testing.T, reg *httpmock.Registry) { } // TODO interactive metadata tests once: 1) we have test utils for Prompter and 2) metadata questions use Prompter + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + + t.Run("non-interactive submission", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&CreateOptions{ + Detector: &fd.EnabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we're not really interested in it. + _ = createRun(&CreateOptions{ + Detector: &fd.DisabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + }) +} diff --git a/pkg/cmd/pr/shared/editable.go b/pkg/cmd/pr/shared/editable.go index 0bebb999a..e73b3c294 100644 --- a/pkg/cmd/pr/shared/editable.go +++ b/pkg/cmd/pr/shared/editable.go @@ -137,7 +137,7 @@ func (e Editable) ProjectIds() (*[]string, error) { s.RemoveValues(e.Projects.Remove) e.Projects.Value = s.ToSlice() } - p, _, err := e.Metadata.ProjectsToIDs(e.Projects.Value) + p, _, err := e.Metadata.ProjectsTitlesToIDs(e.Projects.Value) return &p, err } @@ -171,14 +171,14 @@ func (e Editable) ProjectV2Ids() (*[]string, *[]string, error) { var err error if addTitles.Len() > 0 { - _, addIds, err = e.Metadata.ProjectsToIDs(addTitles.ToSlice()) + _, addIds, err = e.Metadata.ProjectsTitlesToIDs(addTitles.ToSlice()) if err != nil { return nil, nil, err } } if removeTitles.Len() > 0 { - _, removeIds, err = e.Metadata.ProjectsToIDs(removeTitles.ToSlice()) + _, removeIds, err = e.Metadata.ProjectsTitlesToIDs(removeTitles.ToSlice()) if err != nil { return nil, nil, err } diff --git a/pkg/cmd/pr/shared/params.go b/pkg/cmd/pr/shared/params.go index 4f36a80aa..08968939d 100644 --- a/pkg/cmd/pr/shared/params.go +++ b/pkg/cmd/pr/shared/params.go @@ -36,7 +36,7 @@ func WithPrAndIssueQueryParams(client *api.Client, baseRepo ghrepo.Interface, ba q.Set("labels", strings.Join(state.Labels, ",")) } if len(state.ProjectTitles) > 0 { - projectPaths, err := api.ProjectNamesToPaths(client, baseRepo, state.ProjectTitles, projectsV1Support) + projectPaths, err := api.ProjectTitlesToPaths(client, baseRepo, state.ProjectTitles, projectsV1Support) if err != nil { return "", fmt.Errorf("could not add to project: %w", err) } @@ -119,7 +119,7 @@ func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, par } params["labelIds"] = labelIDs - projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsToIDs(tb.ProjectTitles) + projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsTitlesToIDs(tb.ProjectTitles) if err != nil { return fmt.Errorf("could not add to project: %w", err) } From 9822bb5d07fd9f2829572102a01318ca4c3bc909 Mon Sep 17 00:00:00 2001 From: William Martin Date: Thu, 1 May 2025 17:31:42 +0200 Subject: [PATCH 47/51] Feature detect v1 projects on web mode pr create --- pkg/cmd/pr/create/create.go | 2 +- pkg/cmd/pr/create/create_test.go | 115 +++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 483b8246b..6ce11a712 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -400,7 +400,7 @@ func createRun(opts *CreateOptions) error { if err != nil { return err } - openURL, err = generateCompareURL(*ctx, *state, gh.ProjectsV1Supported) + openURL, err = generateCompareURL(*ctx, *state, projectsV1Support) if err != nil { return err } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index f3b99bc89..59a974df6 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -2127,4 +2127,119 @@ func TestProjectsV1Deprecation(t *testing.T) { reg.Verify(t) }) }) + + t.Run("web mode", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&CreateOptions{ + Detector: &fd.EnabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + WebMode: true, + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we're not really interested in it. + _ = createRun(&CreateOptions{ + Detector: &fd.DisabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + WebMode: true, + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + }) } From 5a3aee056a9f0661d7c088a4a634b03e30d2d8eb Mon Sep 17 00:00:00 2001 From: William Martin Date: Fri, 2 May 2025 16:43:36 +0200 Subject: [PATCH 48/51] Feature detect v1 projects on interactive pr create --- internal/prompter/test.go | 12 ++ pkg/cmd/pr/create/create.go | 7 +- pkg/cmd/pr/create/create_test.go | 210 ++++++++++++++++++++++++++++++- 3 files changed, 219 insertions(+), 10 deletions(-) diff --git a/internal/prompter/test.go b/internal/prompter/test.go index 04375ce76..dfa124fca 100644 --- a/internal/prompter/test.go +++ b/internal/prompter/test.go @@ -141,6 +141,18 @@ func IndexFor(options []string, answer string) (int, error) { return -1, NoSuchAnswerErr(answer, options) } +func IndexesFor(options []string, answers ...string) ([]int, error) { + indexes := make([]int, len(answers)) + for i, answer := range answers { + index, err := IndexFor(options, answer) + if err != nil { + return nil, err + } + indexes[i] = index + } + return indexes, nil +} + func NoSuchPromptErr(prompt string) error { return fmt.Errorf("no such prompt '%s'", prompt) } diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 6ce11a712..64469543f 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -456,7 +456,6 @@ func createRun(opts *CreateOptions) error { if err != nil { return err } - // TODO wm: revisit project support return submitPR(*opts, *ctx, *state, projectsV1Support) } @@ -553,8 +552,7 @@ func createRun(opts *CreateOptions) error { Repo: ctx.PRRefs.BaseRepo(), State: state, } - // TODO wm: revisit project support - err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state, gh.ProjectsV1Supported) + err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state, projectsV1Support) if err != nil { return err } @@ -583,12 +581,10 @@ func createRun(opts *CreateOptions) error { if action == shared.SubmitDraftAction { state.Draft = true - // TODO wm: revisit project support return submitPR(*opts, *ctx, *state, projectsV1Support) } if action == shared.SubmitAction { - // TODO wm: revisit project support return submitPR(*opts, *ctx, *state, projectsV1Support) } @@ -1237,7 +1233,6 @@ func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState, proj ctx.PRRefs.BaseRepo(), "compare/%s...%s?expand=1", url.PathEscape(ctx.PRRefs.BaseRef()), url.PathEscape(ctx.PRRefs.QualifiedHeadRef())) - // TODO wm: revisit project support url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state, projectsV1Support) if err != nil { return "", err diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 59a974df6..ac96db0d6 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1943,7 +1943,6 @@ func Test_generateCompareURL(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // TODO wm: projects v1 support? got, err := generateCompareURL(tt.ctx, tt.state, gh.ProjectsV1Supported) if (err != nil) != tt.wantErr { t.Errorf("generateCompareURL() error = %v, wantErr %v", err, tt.wantErr) @@ -2011,8 +2010,6 @@ func mockRetrieveProjects(_ *testing.T, reg *httpmock.Registry) { `)) } -// TODO interactive metadata tests once: 1) we have test utils for Prompter and 2) metadata questions use Prompter - // TODO projectsV1Deprecation // Remove this test. func TestProjectsV1Deprecation(t *testing.T) { @@ -2128,6 +2125,211 @@ func TestProjectsV1Deprecation(t *testing.T) { }) }) + t.Run("interactive submission", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register("git -c log.ShowSignature=false log --pretty=format:%H%x00%s%x00%b%x00 --cherry origin/master...feature", 0, "") + cs.Register(`git rev-parse --show-toplevel`, 0, "") + + // When the command is run + reg := &httpmock.Registry{} + reg.StubRepoResponse("OWNER", "REPO") + + reg.Register( + httpmock.GraphQL(`query PullRequestTemplates\b`), + httpmock.StringResponse(`{ "data": { "repository": { "pullRequestTemplates": [] } } }`), + ) + + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + // Register a handler to check for projects V2 just to avoid the registry panicking, even + // though we return a 500 error. This is because the project lookup is done in parallel + // so the previous error doesn't early exit. + reg.Register( + httpmock.GraphQL(`projectsV2`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + + pm := &prompter.PrompterMock{} + pm.InputFunc = func(p, _ string) (string, error) { + if p == "Title (required)" { + return "Test Title", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.MarkdownEditorFunc = func(p, _ string, ba bool) (string, error) { + if p == "Body" { + return "Test Body", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + switch p { + case "Choose a template": + return 0, nil + case "What's next?": + return prompter.IndexFor(opts, "Add metadata") + default: + return -1, prompter.NoSuchPromptErr(p) + } + } + pm.MultiSelectFunc = func(p string, _ []string, opts []string) ([]int, error) { + return prompter.IndexesFor(opts, "Projects") + } + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: pm, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Detector: &fd.EnabledDetectorMock{}, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + HeadBranch: "feature", + } + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&opts) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register("git -c log.ShowSignature=false log --pretty=format:%H%x00%s%x00%b%x00 --cherry origin/master...feature", 0, "") + cs.Register(`git rev-parse --show-toplevel`, 0, "") + + // When the command is run + reg := &httpmock.Registry{} + reg.StubRepoResponse("OWNER", "REPO") + + reg.Register( + httpmock.GraphQL(`query PullRequestTemplates\b`), + httpmock.StringResponse(`{ "data": { "repository": { "pullRequestTemplates": [] } } }`), + ) + + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + + pm := &prompter.PrompterMock{} + pm.InputFunc = func(p, _ string) (string, error) { + if p == "Title (required)" { + return "Test Title", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.MarkdownEditorFunc = func(p, _ string, ba bool) (string, error) { + if p == "Body" { + return "Test Body", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + switch p { + case "Choose a template": + return 0, nil + case "What's next?": + return prompter.IndexFor(opts, "Add metadata") + default: + return -1, prompter.NoSuchPromptErr(p) + } + } + pm.MultiSelectFunc = func(p string, _ []string, opts []string) ([]int, error) { + return prompter.IndexesFor(opts, "Projects") + } + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: pm, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Detector: &fd.DisabledDetectorMock{}, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + HeadBranch: "feature", + } + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&opts) + + // Verify that our request did not contain projectCards + reg.Verify(t) + }) + }) + t.Run("web mode", func(t *testing.T) { t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { ios, _, _, _ := iostreams.Test() @@ -2238,7 +2440,7 @@ func TestProjectsV1Deprecation(t *testing.T) { Projects: []string{"Project"}, }) - // Verify that our request contained projectCards + // Verify that our request did not contain projectCards reg.Verify(t) }) }) From 1a5b7ca60c25d0484f3d6efc419669162083da9a Mon Sep 17 00:00:00 2001 From: William Martin Date: Fri, 2 May 2025 16:58:46 +0200 Subject: [PATCH 49/51] Feature detect v1 projects for preview URL As far as I can see, when there is project metadata, the preview option will never be shown in the interactive multiselect, so I don't believe this change has any functional difference. However, I did use the opportunity to drive out tests for generateCompareURL --- pkg/cmd/pr/create/create.go | 2 +- pkg/cmd/pr/create/create_test.go | 139 +++++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 7 deletions(-) diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 64469543f..1d980b68d 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -533,7 +533,7 @@ func createRun(opts *CreateOptions) error { } } - openURL, err = generateCompareURL(*ctx, *state, gh.ProjectsV1Supported) + openURL, err = generateCompareURL(*ctx, *state, projectsV1Support) if err != nil { return err } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index ac96db0d6..bd68f19d9 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -1852,11 +1852,13 @@ func mustParseQualifiedHeadRef(ref string) shared.QualifiedHeadRef { func Test_generateCompareURL(t *testing.T) { tests := []struct { - name string - ctx CreateContext - state shared.IssueMetadataState - want string - wantErr bool + name string + ctx CreateContext + state shared.IssueMetadataState + httpStubs func(*testing.T, *httpmock.Registry) + projectsV1Support gh.ProjectsV1Support + want string + wantErr bool }{ { name: "basic", @@ -1940,10 +1942,135 @@ func Test_generateCompareURL(t *testing.T) { want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1&template=story.md", wantErr: false, }, + // TODO projectsV1Deprecation + // Clean up these tests, but probably keep one for general project ID resolution. + { + name: "with projects, no v1 support", + ctx: CreateContext{ + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, + }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + // Ensure no v1 projects are requestd + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "ProjectTitle", "id": "PROJECTV2ID", "resourcePath": "/OWNER/REPO/projects/3" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + }, + state: shared.IssueMetadataState{ + ProjectTitles: []string{"ProjectTitle"}, + }, + projectsV1Support: gh.ProjectsV1Unsupported, + want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1&projects=OWNER%2FREPO%2F3", + wantErr: false, + }, + { + name: "with projects, v1 support", + ctx: CreateContext{ + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, + }, + }, + state: shared.IssueMetadataState{ + ProjectTitles: []string{"ProjectV1Title"}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + // v1 project query responses + reg.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projects": { + "nodes": [ + { "name": "ProjectV1Title", "id": "PROJECTV1ID", "resourcePath": "/OWNER/REPO/projects/1" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projects": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + // v2 project query responses + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + }, + projectsV1Support: gh.ProjectsV1Supported, + want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1&projects=OWNER%2FREPO%2F1", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateCompareURL(tt.ctx, tt.state, gh.ProjectsV1Supported) + // If http stubs are provided, register them and inject the registry into a client + // that is provided to generateCompareURL in the ctx. + if tt.httpStubs != nil { + reg := &httpmock.Registry{} + defer reg.Verify(t) + + tt.httpStubs(t, reg) + tt.ctx.Client = api.NewClientFromHTTP(&http.Client{Transport: reg}) + } + + got, err := generateCompareURL(tt.ctx, tt.state, tt.projectsV1Support) if (err != nil) != tt.wantErr { t.Errorf("generateCompareURL() error = %v, wantErr %v", err, tt.wantErr) return From cc673cfaba6c6fd023271941fa98393b28301cf2 Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Fri, 2 May 2025 14:48:07 -0600 Subject: [PATCH 50/51] test(prompter): add timeout before password input --- internal/prompter/accessible_prompter_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 00947b8f4..ed4da8de8 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -32,6 +32,9 @@ import ( // are sufficient to ensure that the accessible prompter behaves roughly as expected // but doesn't mandate that prompts always look exactly the same. func TestAccessiblePrompter(t *testing.T) { + + beforePasswordSendTimeout := 20 * time.Microsecond + t.Run("Select", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -147,6 +150,9 @@ func TestAccessiblePrompter(t *testing.T) { _, err := console.ExpectString("Enter password") require.NoError(t, err) + // Wait to ensure huh has time to set the echo mode + time.Sleep(beforePasswordSendTimeout) + // Enter a number _, err = console.SendLine(dummyPassword) require.NoError(t, err) @@ -210,6 +216,9 @@ func TestAccessiblePrompter(t *testing.T) { _, err := console.ExpectString("Paste your authentication token:") require.NoError(t, err) + // Wait to ensure huh has time to set the echo mode + time.Sleep(beforePasswordSendTimeout) + // Enter some dummy auth token _, err = console.SendLine(dummyAuthToken) require.NoError(t, err) @@ -243,6 +252,9 @@ func TestAccessiblePrompter(t *testing.T) { _, err = console.ExpectString("token is required") require.NoError(t, err) + // Wait to ensure huh has time to set the echo mode + time.Sleep(beforePasswordSendTimeout) + // Now enter some dummy auth token to return control back to the test _, err = console.SendLine(dummyAuthTokenForAfterFailure) require.NoError(t, err) From ee281fd9bacd77b3bdd37db127f7caf083716a71 Mon Sep 17 00:00:00 2001 From: Azeem Sajid Date: Wed, 7 May 2025 17:59:22 +0500 Subject: [PATCH 51/51] Add `closedByPullRequestsReferences` JSON field to `issue view` (#10941) * [gh issue view] Expose `closedByPullRequestsReferences` JSON fields * Incorporate GitHub Copilot review suggestions * Incorporate review changes --- api/export_pr.go | 19 ++++++++- api/export_pr_test.go | 73 ++++++++++++++++++++++++++++++++- api/queries_issue.go | 22 ++++++++++ api/query_builder.go | 22 ++++++++++ pkg/cmd/issue/view/http.go | 39 ++++++++++++++++++ pkg/cmd/issue/view/view.go | 22 ++++++---- pkg/cmd/issue/view/view_test.go | 1 + 7 files changed, 187 insertions(+), 11 deletions(-) diff --git a/api/export_pr.go b/api/export_pr.go index 7ae1a4ff4..9b030c39e 100644 --- a/api/export_pr.go +++ b/api/export_pr.go @@ -28,6 +28,24 @@ func (issue *Issue) ExportData(fields []string) map[string]interface{} { }) } data[f] = items + case "closedByPullRequestsReferences": + items := make([]map[string]interface{}, 0, len(issue.ClosedByPullRequestsReferences.Nodes)) + for _, n := range issue.ClosedByPullRequestsReferences.Nodes { + items = append(items, map[string]interface{}{ + "id": n.ID, + "number": n.Number, + "url": n.URL, + "repository": map[string]interface{}{ + "id": n.Repository.ID, + "name": n.Repository.Name, + "owner": map[string]interface{}{ + "id": n.Repository.Owner.ID, + "login": n.Repository.Owner.Login, + }, + }, + }) + } + data[f] = items default: sf := fieldByName(v, f) data[f] = sf.Interface() @@ -143,7 +161,6 @@ func (pr *PullRequest) ExportData(fields []string) map[string]interface{} { items := make([]map[string]interface{}, 0, len(pr.ClosingIssuesReferences.Nodes)) for _, n := range pr.ClosingIssuesReferences.Nodes { items = append(items, map[string]interface{}{ - "id": n.ID, "number": n.Number, "url": n.URL, diff --git a/api/export_pr_test.go b/api/export_pr_test.go index 09a1dffe8..1f310693e 100644 --- a/api/export_pr_test.go +++ b/api/export_pr_test.go @@ -107,6 +107,70 @@ func TestIssue_ExportData(t *testing.T) { } `), }, + { + name: "linked pull requests", + fields: []string{"closedByPullRequestsReferences"}, + inputJSON: heredoc.Doc(` + { "closedByPullRequestsReferences": { "nodes": [ + { + "id": "I_123", + "number": 123, + "url": "https://github.com/cli/cli/pull/123", + "repository": { + "id": "R_123", + "name": "cli", + "owner": { + "id": "O_123", + "login": "cli" + } + } + }, + { + "id": "I_456", + "number": 456, + "url": "https://github.com/cli/cli/pull/456", + "repository": { + "id": "R_456", + "name": "cli", + "owner": { + "id": "O_456", + "login": "cli" + } + } + } + ] } } + `), + outputJSON: heredoc.Doc(` + { "closedByPullRequestsReferences": [ + { + "id": "I_123", + "number": 123, + "repository": { + "id": "R_123", + "name": "cli", + "owner": { + "id": "O_123", + "login": "cli" + } + }, + "url": "https://github.com/cli/cli/pull/123" + }, + { + "id": "I_456", + "number": 456, + "repository": { + "id": "R_456", + "name": "cli", + "owner": { + "id": "O_456", + "login": "cli" + } + }, + "url": "https://github.com/cli/cli/pull/456" + } + ] } + `), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -120,7 +184,14 @@ func TestIssue_ExportData(t *testing.T) { enc := json.NewEncoder(&buf) enc.SetIndent("", "\t") require.NoError(t, enc.Encode(exported)) - assert.Equal(t, tt.outputJSON, buf.String()) + + var gotData interface{} + dec = json.NewDecoder(&buf) + require.NoError(t, dec.Decode(&gotData)) + var expectData interface{} + require.NoError(t, json.Unmarshal([]byte(tt.outputJSON), &expectData)) + + assert.Equal(t, expectData, gotData) }) } } diff --git a/api/queries_issue.go b/api/queries_issue.go index 094b6b198..f09360152 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -44,6 +44,28 @@ type Issue struct { Milestone *Milestone ReactionGroups ReactionGroups IsPinned bool + + ClosedByPullRequestsReferences ClosedByPullRequestsReferences +} + +type ClosedByPullRequestsReferences struct { + Nodes []struct { + ID string + Number int + URL string + Repository struct { + ID string + Name string + Owner struct { + ID string + Login string + } + } + } + PageInfo struct { + HasNextPage bool + EndCursor string + } } // return values for Issue.Typename diff --git a/api/query_builder.go b/api/query_builder.go index 4c45da3c1..47fb4c225 100644 --- a/api/query_builder.go +++ b/api/query_builder.go @@ -56,6 +56,25 @@ var issueCommentLast = shortenQuery(` } `) +var issueClosedByPullRequestsReferences = shortenQuery(` + closedByPullRequestsReferences(first: 100) { + nodes { + id, + number, + url, + repository { + id, + name, + owner { + id, + login + } + } + } + pageInfo{hasNextPage,endCursor} + } +`) + var prReviewRequests = shortenQuery(` reviewRequests(first: 100) { nodes { @@ -296,6 +315,7 @@ var sharedIssuePRFields = []string{ var issueOnlyFields = []string{ "isPinned", "stateReason", + "closedByPullRequestsReferences", } var IssueFields = append(sharedIssuePRFields, issueOnlyFields...) @@ -388,6 +408,8 @@ func IssueGraphQL(fields []string) string { q = append(q, StatusCheckRollupGraphQLWithCountByState()) case "closingIssuesReferences": q = append(q, prClosingIssuesReferences) + case "closedByPullRequestsReferences": + q = append(q, issueClosedByPullRequestsReferences) default: q = append(q, field) } diff --git a/pkg/cmd/issue/view/http.go b/pkg/cmd/issue/view/http.go index e4f756436..4adc71802 100644 --- a/pkg/cmd/issue/view/http.go +++ b/pkg/cmd/issue/view/http.go @@ -53,3 +53,42 @@ func preloadIssueComments(client *http.Client, repo ghrepo.Interface, issue *api issue.Comments.PageInfo.HasNextPage = false return nil } + +func preloadClosedByPullRequestsReferences(client *http.Client, repo ghrepo.Interface, issue *api.Issue) error { + if !issue.ClosedByPullRequestsReferences.PageInfo.HasNextPage { + return nil + } + + type response struct { + Node struct { + Issue struct { + ClosedByPullRequestsReferences api.ClosedByPullRequestsReferences `graphql:"closedByPullRequestsReferences(first: 100, after: $endCursor)"` + } `graphql:"...on Issue"` + } `graphql:"node(id: $id)"` + } + + variables := map[string]interface{}{ + "id": githubv4.ID(issue.ID), + "endCursor": githubv4.String(issue.ClosedByPullRequestsReferences.PageInfo.EndCursor), + } + + gql := api.NewClientFromHTTP(client) + + for { + var query response + err := gql.Query(repo.RepoHost(), "closedByPullRequestsReferences", &query, variables) + if err != nil { + return err + } + + issue.ClosedByPullRequestsReferences.Nodes = append(issue.ClosedByPullRequestsReferences.Nodes, query.Node.Issue.ClosedByPullRequestsReferences.Nodes...) + + if !query.Node.Issue.ClosedByPullRequestsReferences.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Node.Issue.ClosedByPullRequestsReferences.PageInfo.EndCursor) + } + + issue.ClosedByPullRequestsReferences.PageInfo.HasNextPage = false + return nil +} diff --git a/pkg/cmd/issue/view/view.go b/pkg/cmd/issue/view/view.go index a9e25513b..3b02a3f2d 100644 --- a/pkg/cmd/issue/view/view.go +++ b/pkg/cmd/issue/view/view.go @@ -1,7 +1,6 @@ package view import ( - "errors" "fmt" "io" "net/http" @@ -134,6 +133,8 @@ func viewRun(opts *ViewOptions) error { opts.IO.DetectTerminalTheme() opts.IO.StartProgressIndicator() + defer opts.IO.StopProgressIndicator() + lookupFields.Add("id") issue, err := issueShared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, lookupFields.ToSlice()) @@ -144,18 +145,21 @@ func viewRun(opts *ViewOptions) error { if lookupFields.Contains("comments") { // FIXME: this re-fetches the comments connection even though the initial set of 100 were // fetched in the previous request. - err = preloadIssueComments(httpClient, baseRepo, issue) - } - opts.IO.StopProgressIndicator() - if err != nil { - var loadErr *issueShared.PartialLoadError - if opts.Exporter == nil && errors.As(err, &loadErr) { - fmt.Fprintf(opts.IO.ErrOut, "warning: %s\n", loadErr.Error()) - } else { + err := preloadIssueComments(httpClient, baseRepo, issue) + if err != nil { return err } } + if lookupFields.Contains("closedByPullRequestsReferences") { + err := preloadClosedByPullRequestsReferences(httpClient, baseRepo, issue) + if err != nil { + return err + } + } + + opts.IO.StopProgressIndicator() + if opts.WebMode { openURL := issue.URL if opts.IO.IsStdoutTTY() { diff --git a/pkg/cmd/issue/view/view_test.go b/pkg/cmd/issue/view/view_test.go index 391a288fb..71b0884a1 100644 --- a/pkg/cmd/issue/view/view_test.go +++ b/pkg/cmd/issue/view/view_test.go @@ -31,6 +31,7 @@ func TestJSONFields(t *testing.T) { "body", "closed", "comments", + "closedByPullRequestsReferences", "createdAt", "closedAt", "id",