diff --git a/pkg/cmd/attestation/api/attestation.go b/pkg/cmd/attestation/api/attestation.go index fd6b484a7..daec12b50 100644 --- a/pkg/cmd/attestation/api/attestation.go +++ b/pkg/cmd/attestation/api/attestation.go @@ -1,7 +1,10 @@ package api import ( + "encoding/json" "errors" + "fmt" + "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -20,3 +23,35 @@ type Attestation struct { type AttestationsResponse struct { Attestations []*Attestation `json:"attestations"` } + +type IntotoStatement struct { + PredicateType string `json:"predicateType"` +} + +func FilterAttestations(predicateType string, attestations []*Attestation) ([]*Attestation, error) { + filteredAttestations := []*Attestation{} + + for _, each := range attestations { + dsseEnvelope := each.Bundle.GetDsseEnvelope() + if dsseEnvelope != nil { + if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { + // Don't fail just because an entry isn't intoto + continue + } + var intotoStatement IntotoStatement + if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { + // Don't fail just because a single entry can't be unmarshalled + continue + } + if intotoStatement.PredicateType == predicateType { + filteredAttestations = append(filteredAttestations, each) + } + } + } + + if len(filteredAttestations) == 0 { + return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) + } + + return filteredAttestations, nil +} diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 1e99a2a06..61d0bee52 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -27,6 +27,28 @@ const ( // Allow injecting backoff interval in tests. var getAttestationRetryInterval = time.Millisecond * 200 +// FetchParams are the parameters for fetching attestations from the GitHub API +type FetchParams struct { + Digest string + Limit int + Owner string + PredicateType string + Repo string +} + +func (p *FetchParams) Validate() error { + if p.Digest == "" { + return fmt.Errorf("digest must be provided") + } + if p.Limit <= 0 || p.Limit > maxLimitForFlag { + return fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) + } + if p.Repo == "" && p.Owner == "" { + return fmt.Errorf("owner or repo must be provided") + } + return nil +} + // githubApiClient makes REST calls to the GitHub API type githubApiClient interface { REST(hostname, method, p string, body io.Reader, data interface{}) error @@ -39,8 +61,7 @@ type httpClient interface { } type Client interface { - GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) - GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) + GetByDigest(params FetchParams) ([]*Attestation, error) GetTrustDomain() (string, error) } @@ -60,22 +81,11 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien } } -// GetByRepoAndDigest fetches the attestation by repo and digest -func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) - return c.getByURL(url, limit) -} - -// GetByOwnerAndDigest fetches attestation by owner and digest -func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest) - return c.getByURL(url, limit) -} - -func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) { - attestations, err := c.getAttestations(url, limit) +// GetByDigest fetches the attestation by digest and either owner or repo +// depending on which is provided +func (c *LiveClient) GetByDigest(params FetchParams) ([]*Attestation, error) { + c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", params.Digest) + attestations, err := c.getAttestations(params) if err != nil { return nil, err } @@ -88,40 +98,52 @@ func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) { return bundles, nil } -// GetTrustDomain returns the current trust domain. If the default is used -// the empty string is returned -func (c *LiveClient) GetTrustDomain() (string, error) { - return c.getTrustDomain(MetaPath) -} - -func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, error) { - perPage := limit - if perPage <= 0 || perPage > maxLimitForFlag { - return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) +func (c *LiveClient) buildRequestURL(params FetchParams) (string, error) { + if err := params.Validate(); err != nil { + return "", err } + var url string + if params.Repo != "" { + // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. + // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. + url = fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) + } else { + url = fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) + } + + perPage := params.Limit if perPage > maxLimitForFetch { perPage = maxLimitForFetch } // ref: https://github.com/cli/go-gh/blob/d32c104a9a25c9de3d7c7b07a43ae0091441c858/example_gh_test.go#L96 url = fmt.Sprintf("%s?per_page=%d", url, perPage) + if params.PredicateType != "" { + url = fmt.Sprintf("%s&predicate_type=%s", url, params.PredicateType) + } + return url, nil +} + +func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) { + url, err := c.buildRequestURL(params) + if err != nil { + return nil, err + } var attestations []*Attestation var resp AttestationsResponse bo := backoff.NewConstantBackOff(getAttestationRetryInterval) // if no attestation or less than limit, then keep fetching - for url != "" && len(attestations) < limit { + for url != "" && len(attestations) < params.Limit { err := backoff.Retry(func() error { newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) - if restErr != nil { if shouldRetry(restErr) { return restErr - } else { - return backoff.Permanent(restErr) } + return backoff.Permanent(restErr) } url = newURL @@ -140,8 +162,8 @@ func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, err return nil, ErrNoAttestationsFound } - if len(attestations) > limit { - return attestations[:limit], nil + if len(attestations) > params.Limit { + return attestations[:params.Limit], nil } return attestations, nil @@ -241,6 +263,12 @@ func shouldRetry(err error) bool { return false } +// GetTrustDomain returns the current trust domain. If the default is used +// the empty string is returned +func (c *LiveClient) GetTrustDomain() (string, error) { + return c.getTrustDomain(MetaPath) +} + func (c *LiveClient) getTrustDomain(url string) (string, error) { var resp MetaResponse diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 787408a4e..384c7c9c8 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -42,78 +42,75 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } } +var testFetchParamsWithOwner = FetchParams{ + Digest: testDigest, + Limit: DefaultLimit, + Owner: testOwner, + PredicateType: "https://slsa.dev/provenance/v1", +} +var testFetchParamsWithRepo = FetchParams{ + Digest: testDigest, + Limit: DefaultLimit, + Repo: testRepo, + PredicateType: "https://slsa.dev/provenance/v1", +} + +type getByTestCase struct { + name string + params FetchParams + limit int + expectedAttestations int + hasNextPage bool +} + +var getByTestCases = []getByTestCase{ + { + name: "get by digest with owner", + params: testFetchParamsWithOwner, + expectedAttestations: 5, + }, + { + name: "get by digest with repo", + params: testFetchParamsWithRepo, + expectedAttestations: 5, + }, + { + name: "get by digest with attestations greater than limit", + params: testFetchParamsWithRepo, + limit: 3, + expectedAttestations: 3, + }, + { + name: "get by digest with next page", + params: testFetchParamsWithRepo, + expectedAttestations: 10, + hasNextPage: true, + }, + { + name: "greater than limit with next page", + params: testFetchParamsWithRepo, + limit: 7, + expectedAttestations: 7, + hasNextPage: true, + }, +} + func TestGetByDigest(t *testing.T) { - c := NewClientWithMockGHClient(false) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.NoError(t, err) + for _, tc := range getByTestCases { + t.Run(tc.name, func(t *testing.T) { + c := NewClientWithMockGHClient(tc.hasNextPage) - require.Equal(t, 5, len(attestations)) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + if tc.limit > 0 { + tc.params.Limit = tc.limit + } + attestations, err := c.GetByDigest(tc.params) + require.NoError(t, err) - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, 5, len(attestations)) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") -} - -func TestGetByDigestGreaterThanLimit(t *testing.T) { - c := NewClientWithMockGHClient(false) - - limit := 3 - // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, 3, len(attestations)) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") -} - -func TestGetByDigestWithNextPage(t *testing.T) { - c := NewClientWithMockGHClient(true) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, len(attestations), 10) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, len(attestations), 10) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") -} - -func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { - c := NewClientWithMockGHClient(true) - - limit := 7 - // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + require.Equal(t, tc.expectedAttestations, len(attestations)) + bundle := (attestations)[0].Bundle + require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + }) + } } func TestGetByDigest_NoAttestationsFound(t *testing.T) { @@ -130,12 +127,7 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.Error(t, err) - require.IsType(t, ErrNoAttestationsFound, err) - require.Nil(t, attestations) - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) @@ -153,11 +145,7 @@ func TestGetByDigest_Error(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.Error(t, err) - require.Nil(t, attestations) - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) require.Nil(t, attestations) } @@ -362,7 +350,8 @@ func TestGetAttestationsRetries(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + testFetchParamsWithRepo.Limit = 30 + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.NoError(t, err) // assert the error path was executed; because this is a paged @@ -373,17 +362,6 @@ func TestGetAttestationsRetries(t *testing.T) { require.Equal(t, len(attestations), 10) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - // same test as above, but for GetByOwnerAndDigest: - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) - require.NoError(t, err) - - // because we haven't reset the mock, we have added 2 more failed requests - fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4) - - require.Equal(t, len(attestations), 10) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") } // test total retries @@ -401,7 +379,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { logger: io.NewTestHandler(), } - _, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + _, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index b2fd334c0..b6062b39f 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -6,58 +6,49 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/test/data" ) +func makeTestAttestation() Attestation { + return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} +} + type MockClient struct { - OnGetByRepoAndDigest func(repo, digest string, limit int) ([]*Attestation, error) - OnGetByOwnerAndDigest func(owner, digest string, limit int) ([]*Attestation, error) - OnGetTrustDomain func() (string, error) + OnGetByDigest func(params FetchParams) ([]*Attestation, error) + OnGetTrustDomain func() (string, error) } -func (m MockClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - return m.OnGetByRepoAndDigest(repo, digest, limit) -} - -func (m MockClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { - return m.OnGetByOwnerAndDigest(owner, digest, limit) +func (m MockClient) GetByDigest(params FetchParams) ([]*Attestation, error) { + return m.OnGetByDigest(params) } func (m MockClient) GetTrustDomain() (string, error) { return m.OnGetTrustDomain() } -func makeTestAttestation() Attestation { - return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} -} - -func OnGetByRepoAndDigestSuccess(repo, digest string, limit int) ([]*Attestation, error) { +func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil + attestations := []*Attestation{&att1, &att2} + if params.PredicateType != "" { + return FilterAttestations(params.PredicateType, attestations) + } + + return attestations, nil } -func OnGetByRepoAndDigestFailure(repo, digest string, limit int) ([]*Attestation, error) { - return nil, fmt.Errorf("failed to fetch by repo and digest") -} - -func OnGetByOwnerAndDigestSuccess(owner, digest string, limit int) ([]*Attestation, error) { - att1 := makeTestAttestation() - att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil -} - -func OnGetByOwnerAndDigestFailure(owner, digest string, limit int) ([]*Attestation, error) { - return nil, fmt.Errorf("failed to fetch by owner and digest") +func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) { + if params.Repo != "" { + return nil, fmt.Errorf("failed to fetch attestations from %s", params.Repo) + } + return nil, fmt.Errorf("failed to fetch attestations from %s", params.Owner) } func NewTestClient() *MockClient { return &MockClient{ - OnGetByRepoAndDigest: OnGetByRepoAndDigestSuccess, - OnGetByOwnerAndDigest: OnGetByOwnerAndDigestSuccess, + OnGetByDigest: OnGetByDigestSuccess, } } func NewFailTestClient() *MockClient { return &MockClient{ - OnGetByRepoAndDigest: OnGetByRepoAndDigestFailure, - OnGetByOwnerAndDigest: OnGetByOwnerAndDigestFailure, + OnGetByDigest: OnGetByDigestFailure, } } diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 6913c0787..8d1d1dc05 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -9,7 +9,6 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/auth" "github.com/cli/cli/v2/pkg/cmd/attestation/io" - "github.com/cli/cli/v2/pkg/cmd/attestation/verification" "github.com/cli/cli/v2/pkg/cmdutil" ghauth "github.com/cli/go-gh/v2/pkg/auth" @@ -127,13 +126,16 @@ func runDownload(opts *Options) error { opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath) - params := verification.FetchRemoteAttestationsParams{ + if opts.APIClient == nil { + return fmt.Errorf("no APIClient provided") + } + params := api.FetchParams{ Digest: artifact.DigestWithAlg(), Limit: opts.Limit, Owner: opts.Owner, Repo: opts.Repo, } - attestations, err := verification.GetRemoteAttestations(opts.APIClient, params) + attestations, err := opts.APIClient.GetByDigest(params) if err != nil { if errors.Is(err, api.ErrNoAttestationsFound) { fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath) @@ -144,10 +146,9 @@ func runDownload(opts *Options) error { // Apply predicate type filter to returned attestations if opts.PredicateType != "" { - filteredAttestations := verification.FilterAttestations(opts.PredicateType, attestations) - - if len(filteredAttestations) == 0 { - return fmt.Errorf("no attestations found with predicate type: %s", opts.PredicateType) + filteredAttestations, err := api.FilterAttestations(opts.PredicateType, attestations) + if err != nil { + return fmt.Errorf("failed to filter attestations: %v", err) } attestations = filteredAttestations diff --git a/pkg/cmd/attestation/download/download_test.go b/pkg/cmd/attestation/download/download_test.go index ddcd08c92..11872daf9 100644 --- a/pkg/cmd/attestation/download/download_test.go +++ b/pkg/cmd/attestation/download/download_test.go @@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) { t.Run("no attestations found", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { + OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, api.ErrNoAttestationsFound }, } @@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) { t.Run("failed to fetch attestations", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { + OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, fmt.Errorf("failed to fetch attestations") }, } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index db419ebac..10eb02ac4 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -20,13 +20,6 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1" var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl") var ErrEmptyBundleFile = errors.New("provided bundle file is empty") -type FetchRemoteAttestationsParams struct { - Digest string - Limit int - Owner string - Repo string -} - // GetLocalAttestations returns a slice of attestations read from a local bundle file. func GetLocalAttestations(path string) ([]*api.Attestation, error) { var attestations []*api.Attestation @@ -89,28 +82,6 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) { return attestations, nil } -func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) { - if client == nil { - return nil, fmt.Errorf("api client must be provided") - } - // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. - // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - if params.Repo != "" { - attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) - } - return attestations, nil - } else if params.Owner != "" { - attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) - } - return attestations, nil - } - return nil, fmt.Errorf("owner or repo must be provided") -} - func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) { attestations, err := client.GetAttestations(artifact.NameRef(), artifact.DigestWithAlg()) if err != nil { @@ -121,31 +92,3 @@ func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ( } return attestations, nil } - -type IntotoStatement struct { - PredicateType string `json:"predicateType"` -} - -func FilterAttestations(predicateType string, attestations []*api.Attestation) []*api.Attestation { - filteredAttestations := []*api.Attestation{} - - for _, each := range attestations { - dsseEnvelope := each.Bundle.GetDsseEnvelope() - if dsseEnvelope != nil { - if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { - // Don't fail just because an entry isn't intoto - continue - } - var intotoStatement IntotoStatement - if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { - // Don't fail just because a single entry can't be unmarshalled - continue - } - if intotoStatement.PredicateType == predicateType { - filteredAttestations = append(filteredAttestations, each) - } - } - } - - return filteredAttestations -} diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index 8acff0c37..6826e2e40 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -157,10 +157,11 @@ func TestFilterAttestations(t *testing.T) { }, } - filtered := FilterAttestations("https://slsa.dev/provenance/v1", attestations) - + filtered, err := api.FilterAttestations("https://slsa.dev/provenance/v1", attestations) require.Len(t, filtered, 1) + require.NoError(t, err) - filtered = FilterAttestations("NonExistentPredicate", attestations) - require.Len(t, filtered, 0) + filtered, err = api.FilterAttestations("NonExistentPredicate", attestations) + require.Nil(t, filtered) + require.Error(t, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index bb96c9526..1b98fabf3 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -1,6 +1,7 @@ package verify import ( + "errors" "fmt" "github.com/cli/cli/v2/internal/text" @@ -10,43 +11,63 @@ import ( ) func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { + // Fetch attestations from GitHub API within this if block since predicate type + // filter is done when the API is called + if o.FetchAttestationsFromGitHubAPI() { + if o.APIClient == nil { + errMsg := "✗ No APIClient provided" + return nil, errMsg, errors.New(errMsg) + } + + params := api.FetchParams{ + Digest: a.DigestWithAlg(), + Limit: o.Limit, + Owner: o.Owner, + PredicateType: o.PredicateType, + Repo: o.Repo, + } + + attestations, err := o.APIClient.GetByDigest(params) + if err != nil { + msg := "✗ Loading attestations from GitHub API failed" + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) + return attestations, msg, nil + } + + // Fetch attestations from local bundle or OCI registry + // Predicate type filtering is done after the attestations are fetched + var attestations []*api.Attestation + var err error + var msg string if o.BundlePath != "" { - attestations, err := verification.GetLocalAttestations(o.BundlePath) + attestations, err = verification.GetLocalAttestations(o.BundlePath) if err != nil { - msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) - return nil, msg, err + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg = fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + } else { + msg = fmt.Sprintf("Loaded %d attestations from %s", len(attestations), o.BundlePath) } - pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) - return attestations, msg, nil - } - - if o.UseBundleFromRegistry { - attestations, err := verification.GetOCIAttestations(o.OCIClient, a) + } else if o.UseBundleFromRegistry { + attestations, err = verification.GetOCIAttestations(o.OCIClient, a) if err != nil { - msg := "✗ Loading attestations from OCI registry failed" - return nil, msg, err + msg = "✗ Loading attestations from OCI registry failed" + } else { + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg = fmt.Sprintf("Loaded %s from OCI registry", pluralAttestation) } - pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) - return attestations, msg, nil } - - params := verification.FetchRemoteAttestationsParams{ - Digest: a.DigestWithAlg(), - Limit: o.Limit, - Owner: o.Owner, - Repo: o.Repo, - } - - attestations, err := verification.GetRemoteAttestations(o.APIClient, params) if err != nil { - msg := "✗ Loading attestations from GitHub API failed" return nil, msg, err } - pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) - return attestations, msg, nil + + filtered, err := api.FilterAttestations(o.PredicateType, attestations) + if err != nil { + return nil, err.Error(), err + } + return filtered, msg, nil } func verifyAttestations(art artifact.DigestedArtifact, att []*api.Attestation, sgVerifier verification.SigstoreVerifier, ec verification.EnforcementCriteria) ([]*verification.AttestationProcessingResult, string, error) { diff --git a/pkg/cmd/attestation/verify/attestation_test.go b/pkg/cmd/attestation/verify/attestation_test.go new file mode 100644 index 000000000..f015805ae --- /dev/null +++ b/pkg/cmd/attestation/verify/attestation_test.go @@ -0,0 +1,71 @@ +package verify + +import ( + "testing" + + "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" + "github.com/cli/cli/v2/pkg/cmd/attestation/verification" + "github.com/stretchr/testify/require" +) + +func TestGetAttestations_OCIRegistry_PredicateTypeFiltering(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + OCIClient: oci.MockClient{}, + PredicateType: verification.SLSAPredicateV1, + Repo: "cli/cli", + UseBundleFromRegistry: true, + } + attestations, msg, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Contains(t, msg, "Loaded 2 attestations from OCI registry") + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, msg, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Contains(t, msg, "no attestations found with predicate type") + require.Nil(t, attestations) +} + +func TestGetAttestations_LocalBundle_PredicateTypeFiltering(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + BundlePath: "../test/data/sigstore-js-2.1.0-bundle.json", + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, _, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Len(t, attestations, 1) + + o.PredicateType = "custom predicate type" + attestations, _, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Nil(t, attestations) +} + +func TestGetAttestations_GhAPI_NoAttestationsFound(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + APIClient: api.NewTestClient(), + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, _, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, _, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Nil(t, attestations) +} diff --git a/pkg/cmd/attestation/verify/options.go b/pkg/cmd/attestation/verify/options.go index 0fbbec55a..e47c4f4a8 100644 --- a/pkg/cmd/attestation/verify/options.go +++ b/pkg/cmd/attestation/verify/options.go @@ -53,6 +53,12 @@ func (opts *Options) Clean() { } } +// FetchAttestationsFromGitHubAPI returns true if the command should fetch attestations from the GitHub API +// It checks that a bundle path is not provided and that the "use bundle from registry" flag is not set +func (opts *Options) FetchAttestationsFromGitHubAPI() bool { + return opts.BundlePath == "" && !opts.UseBundleFromRegistry +} + // AreFlagsValid checks that the provided flag combination is valid // and returns an error otherwise func (opts *Options) AreFlagsValid() error { diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index b3bad519a..b8debc529 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -288,14 +288,6 @@ func runVerify(opts *Options) error { // Print the message signifying success fetching attestations opts.Logger.Println(logMsg) - // Apply predicate type filter to returned attestations - filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations) - if len(filteredAttestations) == 0 { - opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found with predicate type: %s\n"), opts.PredicateType) - return fmt.Errorf("no matching predicate found") - } - attestations = filteredAttestations - // print information about the policy that will be enforced against attestations opts.Logger.Println("\nThe following policy criteria will be enforced:") opts.Logger.Println(ec.BuildPolicyInformation()) diff --git a/pkg/cmd/attestation/verify/verify_test.go b/pkg/cmd/attestation/verify/verify_test.go index 092a009d8..2b821a435 100644 --- a/pkg/cmd/attestation/verify/verify_test.go +++ b/pkg/cmd/attestation/verify/verify_test.go @@ -510,7 +510,7 @@ func TestRunVerify(t *testing.T) { err := runVerify(&customOpts) require.Error(t, err) - require.ErrorContains(t, err, "no matching predicate found") + require.ErrorContains(t, err, "no attestations found with predicate type") }) t.Run("with valid OCI artifact with UseBundleFromRegistry flag but no bundle return from registry", func(t *testing.T) { diff --git a/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh b/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh index 647a13a4c..cea3c7228 100644 --- a/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh +++ b/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh @@ -14,3 +14,9 @@ if ! $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owne echo "Failed to verify" exit 1 fi + +# Try to verify when specifying a predicate type that does not match the attestation +if $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owner=cli --predicate-type=my-custom-predicate-type; then + echo "Verification should have failed" + exit 1 +fi