Merge pull request #10670 from malancas/move-predicate-type-filtering

Move predicate type filtering in `gh attestation verify`
This commit is contained in:
Meredith Lancaster 2025-05-06 11:08:34 -06:00 committed by GitHub
commit 315876852a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 338 additions and 265 deletions

View file

@ -1,7 +1,10 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"github.com/sigstore/sigstore-go/pkg/bundle"
)
@ -20,3 +23,35 @@ type Attestation struct {
type AttestationsResponse struct {
Attestations []*Attestation `json:"attestations"`
}
type IntotoStatement struct {
PredicateType string `json:"predicateType"`
}
func FilterAttestations(predicateType string, attestations []*Attestation) ([]*Attestation, error) {
filteredAttestations := []*Attestation{}
for _, each := range attestations {
dsseEnvelope := each.Bundle.GetDsseEnvelope()
if dsseEnvelope != nil {
if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" {
// Don't fail just because an entry isn't intoto
continue
}
var intotoStatement IntotoStatement
if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil {
// Don't fail just because a single entry can't be unmarshalled
continue
}
if intotoStatement.PredicateType == predicateType {
filteredAttestations = append(filteredAttestations, each)
}
}
}
if len(filteredAttestations) == 0 {
return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType)
}
return filteredAttestations, nil
}

View file

@ -27,6 +27,28 @@ const (
// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200
// FetchParams are the parameters for fetching attestations from the GitHub API
type FetchParams struct {
Digest string
Limit int
Owner string
PredicateType string
Repo string
}
func (p *FetchParams) Validate() error {
if p.Digest == "" {
return fmt.Errorf("digest must be provided")
}
if p.Limit <= 0 || p.Limit > maxLimitForFlag {
return fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag)
}
if p.Repo == "" && p.Owner == "" {
return fmt.Errorf("owner or repo must be provided")
}
return nil
}
// githubApiClient makes REST calls to the GitHub API
type githubApiClient interface {
REST(hostname, method, p string, body io.Reader, data interface{}) error
@ -39,8 +61,7 @@ type httpClient interface {
}
type Client interface {
GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error)
GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error)
GetByDigest(params FetchParams) ([]*Attestation, error)
GetTrustDomain() (string, error)
}
@ -60,22 +81,11 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien
}
}
// GetByRepoAndDigest fetches the attestation by repo and digest
func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)
url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest)
return c.getByURL(url, limit)
}
// GetByOwnerAndDigest fetches attestation by owner and digest
func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)
url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest)
return c.getByURL(url, limit)
}
func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) {
attestations, err := c.getAttestations(url, limit)
// GetByDigest fetches the attestation by digest and either owner or repo
// depending on which is provided
func (c *LiveClient) GetByDigest(params FetchParams) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", params.Digest)
attestations, err := c.getAttestations(params)
if err != nil {
return nil, err
}
@ -88,40 +98,52 @@ func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) {
return bundles, nil
}
// GetTrustDomain returns the current trust domain. If the default is used
// the empty string is returned
func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}
func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, error) {
perPage := limit
if perPage <= 0 || perPage > maxLimitForFlag {
return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag)
func (c *LiveClient) buildRequestURL(params FetchParams) (string, error) {
if err := params.Validate(); err != nil {
return "", err
}
var url string
if params.Repo != "" {
// check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo.
// If Repo is not set, the field will remain empty. It will not be populated using the value of Owner.
url = fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest)
} else {
url = fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest)
}
perPage := params.Limit
if perPage > maxLimitForFetch {
perPage = maxLimitForFetch
}
// ref: https://github.com/cli/go-gh/blob/d32c104a9a25c9de3d7c7b07a43ae0091441c858/example_gh_test.go#L96
url = fmt.Sprintf("%s?per_page=%d", url, perPage)
if params.PredicateType != "" {
url = fmt.Sprintf("%s&predicate_type=%s", url, params.PredicateType)
}
return url, nil
}
func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) {
url, err := c.buildRequestURL(params)
if err != nil {
return nil, err
}
var attestations []*Attestation
var resp AttestationsResponse
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
for url != "" && len(attestations) < params.Limit {
err := backoff.Retry(func() error {
newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
return backoff.Permanent(restErr)
}
url = newURL
@ -140,8 +162,8 @@ func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, err
return nil, ErrNoAttestationsFound
}
if len(attestations) > limit {
return attestations[:limit], nil
if len(attestations) > params.Limit {
return attestations[:params.Limit], nil
}
return attestations, nil
@ -241,6 +263,12 @@ func shouldRetry(err error) bool {
return false
}
// GetTrustDomain returns the current trust domain. If the default is used
// the empty string is returned
func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}
func (c *LiveClient) getTrustDomain(url string) (string, error) {
var resp MetaResponse

View file

@ -42,78 +42,75 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
}
}
var testFetchParamsWithOwner = FetchParams{
Digest: testDigest,
Limit: DefaultLimit,
Owner: testOwner,
PredicateType: "https://slsa.dev/provenance/v1",
}
var testFetchParamsWithRepo = FetchParams{
Digest: testDigest,
Limit: DefaultLimit,
Repo: testRepo,
PredicateType: "https://slsa.dev/provenance/v1",
}
type getByTestCase struct {
name string
params FetchParams
limit int
expectedAttestations int
hasNextPage bool
}
var getByTestCases = []getByTestCase{
{
name: "get by digest with owner",
params: testFetchParamsWithOwner,
expectedAttestations: 5,
},
{
name: "get by digest with repo",
params: testFetchParamsWithRepo,
expectedAttestations: 5,
},
{
name: "get by digest with attestations greater than limit",
params: testFetchParamsWithRepo,
limit: 3,
expectedAttestations: 3,
},
{
name: "get by digest with next page",
params: testFetchParamsWithRepo,
expectedAttestations: 10,
hasNextPage: true,
},
{
name: "greater than limit with next page",
params: testFetchParamsWithRepo,
limit: 7,
expectedAttestations: 7,
hasNextPage: true,
},
}
func TestGetByDigest(t *testing.T) {
c := NewClientWithMockGHClient(false)
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.NoError(t, err)
for _, tc := range getByTestCases {
t.Run(tc.name, func(t *testing.T) {
c := NewClientWithMockGHClient(tc.hasNextPage)
require.Equal(t, 5, len(attestations))
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
if tc.limit > 0 {
tc.params.Limit = tc.limit
}
attestations, err := c.GetByDigest(tc.params)
require.NoError(t, err)
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
require.NoError(t, err)
require.Equal(t, 5, len(attestations))
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
}
func TestGetByDigestGreaterThanLimit(t *testing.T) {
c := NewClientWithMockGHClient(false)
limit := 3
// The method should return five results when the limit is not set
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit)
require.NoError(t, err)
require.Equal(t, 3, len(attestations))
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit)
require.NoError(t, err)
require.Equal(t, len(attestations), limit)
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
}
func TestGetByDigestWithNextPage(t *testing.T) {
c := NewClientWithMockGHClient(true)
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.NoError(t, err)
require.Equal(t, len(attestations), 10)
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
require.NoError(t, err)
require.Equal(t, len(attestations), 10)
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
}
func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) {
c := NewClientWithMockGHClient(true)
limit := 7
// The method should return five results when the limit is not set
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit)
require.NoError(t, err)
require.Equal(t, len(attestations), limit)
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit)
require.NoError(t, err)
require.Equal(t, len(attestations), limit)
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
require.Equal(t, tc.expectedAttestations, len(attestations))
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
})
}
}
func TestGetByDigest_NoAttestationsFound(t *testing.T) {
@ -130,12 +127,7 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) {
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.Error(t, err)
require.IsType(t, ErrNoAttestationsFound, err)
require.Nil(t, attestations)
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
attestations, err := c.GetByDigest(testFetchParamsWithRepo)
require.Error(t, err)
require.IsType(t, ErrNoAttestationsFound, err)
require.Nil(t, attestations)
@ -153,11 +145,7 @@ func TestGetByDigest_Error(t *testing.T) {
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.Error(t, err)
require.Nil(t, attestations)
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
attestations, err := c.GetByDigest(testFetchParamsWithRepo)
require.Error(t, err)
require.Nil(t, attestations)
}
@ -362,7 +350,8 @@ func TestGetAttestationsRetries(t *testing.T) {
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
testFetchParamsWithRepo.Limit = 30
attestations, err := c.GetByDigest(testFetchParamsWithRepo)
require.NoError(t, err)
// assert the error path was executed; because this is a paged
@ -373,17 +362,6 @@ func TestGetAttestationsRetries(t *testing.T) {
require.Equal(t, len(attestations), 10)
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
// same test as above, but for GetByOwnerAndDigest:
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
require.NoError(t, err)
// because we haven't reset the mock, we have added 2 more failed requests
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4)
require.Equal(t, len(attestations), 10)
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
}
// test total retries
@ -401,7 +379,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) {
logger: io.NewTestHandler(),
}
_, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
_, err := c.GetByDigest(testFetchParamsWithRepo)
require.Error(t, err)
fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4)

View file

@ -6,58 +6,49 @@ import (
"github.com/cli/cli/v2/pkg/cmd/attestation/test/data"
)
func makeTestAttestation() Attestation {
return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"}
}
type MockClient struct {
OnGetByRepoAndDigest func(repo, digest string, limit int) ([]*Attestation, error)
OnGetByOwnerAndDigest func(owner, digest string, limit int) ([]*Attestation, error)
OnGetTrustDomain func() (string, error)
OnGetByDigest func(params FetchParams) ([]*Attestation, error)
OnGetTrustDomain func() (string, error)
}
func (m MockClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) {
return m.OnGetByRepoAndDigest(repo, digest, limit)
}
func (m MockClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) {
return m.OnGetByOwnerAndDigest(owner, digest, limit)
func (m MockClient) GetByDigest(params FetchParams) ([]*Attestation, error) {
return m.OnGetByDigest(params)
}
func (m MockClient) GetTrustDomain() (string, error) {
return m.OnGetTrustDomain()
}
func makeTestAttestation() Attestation {
return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"}
}
func OnGetByRepoAndDigestSuccess(repo, digest string, limit int) ([]*Attestation, error) {
func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) {
att1 := makeTestAttestation()
att2 := makeTestAttestation()
return []*Attestation{&att1, &att2}, nil
attestations := []*Attestation{&att1, &att2}
if params.PredicateType != "" {
return FilterAttestations(params.PredicateType, attestations)
}
return attestations, nil
}
func OnGetByRepoAndDigestFailure(repo, digest string, limit int) ([]*Attestation, error) {
return nil, fmt.Errorf("failed to fetch by repo and digest")
}
func OnGetByOwnerAndDigestSuccess(owner, digest string, limit int) ([]*Attestation, error) {
att1 := makeTestAttestation()
att2 := makeTestAttestation()
return []*Attestation{&att1, &att2}, nil
}
func OnGetByOwnerAndDigestFailure(owner, digest string, limit int) ([]*Attestation, error) {
return nil, fmt.Errorf("failed to fetch by owner and digest")
func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) {
if params.Repo != "" {
return nil, fmt.Errorf("failed to fetch attestations from %s", params.Repo)
}
return nil, fmt.Errorf("failed to fetch attestations from %s", params.Owner)
}
func NewTestClient() *MockClient {
return &MockClient{
OnGetByRepoAndDigest: OnGetByRepoAndDigestSuccess,
OnGetByOwnerAndDigest: OnGetByOwnerAndDigestSuccess,
OnGetByDigest: OnGetByDigestSuccess,
}
}
func NewFailTestClient() *MockClient {
return &MockClient{
OnGetByRepoAndDigest: OnGetByRepoAndDigestFailure,
OnGetByOwnerAndDigest: OnGetByOwnerAndDigestFailure,
OnGetByDigest: OnGetByDigestFailure,
}
}

View file

@ -9,7 +9,6 @@ import (
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci"
"github.com/cli/cli/v2/pkg/cmd/attestation/auth"
"github.com/cli/cli/v2/pkg/cmd/attestation/io"
"github.com/cli/cli/v2/pkg/cmd/attestation/verification"
"github.com/cli/cli/v2/pkg/cmdutil"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
@ -127,13 +126,16 @@ func runDownload(opts *Options) error {
opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath)
params := verification.FetchRemoteAttestationsParams{
if opts.APIClient == nil {
return fmt.Errorf("no APIClient provided")
}
params := api.FetchParams{
Digest: artifact.DigestWithAlg(),
Limit: opts.Limit,
Owner: opts.Owner,
Repo: opts.Repo,
}
attestations, err := verification.GetRemoteAttestations(opts.APIClient, params)
attestations, err := opts.APIClient.GetByDigest(params)
if err != nil {
if errors.Is(err, api.ErrNoAttestationsFound) {
fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath)
@ -144,10 +146,9 @@ func runDownload(opts *Options) error {
// Apply predicate type filter to returned attestations
if opts.PredicateType != "" {
filteredAttestations := verification.FilterAttestations(opts.PredicateType, attestations)
if len(filteredAttestations) == 0 {
return fmt.Errorf("no attestations found with predicate type: %s", opts.PredicateType)
filteredAttestations, err := api.FilterAttestations(opts.PredicateType, attestations)
if err != nil {
return fmt.Errorf("failed to filter attestations: %v", err)
}
attestations = filteredAttestations

View file

@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) {
t.Run("no attestations found", func(t *testing.T) {
opts := baseOpts
opts.APIClient = api.MockClient{
OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) {
OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) {
return nil, api.ErrNoAttestationsFound
},
}
@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) {
t.Run("failed to fetch attestations", func(t *testing.T) {
opts := baseOpts
opts.APIClient = api.MockClient{
OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) {
OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) {
return nil, fmt.Errorf("failed to fetch attestations")
},
}

View file

@ -20,13 +20,6 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1"
var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl")
var ErrEmptyBundleFile = errors.New("provided bundle file is empty")
type FetchRemoteAttestationsParams struct {
Digest string
Limit int
Owner string
Repo string
}
// GetLocalAttestations returns a slice of attestations read from a local bundle file.
func GetLocalAttestations(path string) ([]*api.Attestation, error) {
var attestations []*api.Attestation
@ -89,28 +82,6 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) {
return attestations, nil
}
func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) {
if client == nil {
return nil, fmt.Errorf("api client must be provided")
}
// check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo.
// If Repo is not set, the field will remain empty. It will not be populated using the value of Owner.
if params.Repo != "" {
attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err)
}
return attestations, nil
} else if params.Owner != "" {
attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err)
}
return attestations, nil
}
return nil, fmt.Errorf("owner or repo must be provided")
}
func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) {
attestations, err := client.GetAttestations(artifact.NameRef(), artifact.DigestWithAlg())
if err != nil {
@ -121,31 +92,3 @@ func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) (
}
return attestations, nil
}
type IntotoStatement struct {
PredicateType string `json:"predicateType"`
}
func FilterAttestations(predicateType string, attestations []*api.Attestation) []*api.Attestation {
filteredAttestations := []*api.Attestation{}
for _, each := range attestations {
dsseEnvelope := each.Bundle.GetDsseEnvelope()
if dsseEnvelope != nil {
if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" {
// Don't fail just because an entry isn't intoto
continue
}
var intotoStatement IntotoStatement
if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil {
// Don't fail just because a single entry can't be unmarshalled
continue
}
if intotoStatement.PredicateType == predicateType {
filteredAttestations = append(filteredAttestations, each)
}
}
}
return filteredAttestations
}

View file

@ -157,10 +157,11 @@ func TestFilterAttestations(t *testing.T) {
},
}
filtered := FilterAttestations("https://slsa.dev/provenance/v1", attestations)
filtered, err := api.FilterAttestations("https://slsa.dev/provenance/v1", attestations)
require.Len(t, filtered, 1)
require.NoError(t, err)
filtered = FilterAttestations("NonExistentPredicate", attestations)
require.Len(t, filtered, 0)
filtered, err = api.FilterAttestations("NonExistentPredicate", attestations)
require.Nil(t, filtered)
require.Error(t, err)
}

View file

@ -1,6 +1,7 @@
package verify
import (
"errors"
"fmt"
"github.com/cli/cli/v2/internal/text"
@ -10,43 +11,63 @@ import (
)
func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) {
// Fetch attestations from GitHub API within this if block since predicate type
// filter is done when the API is called
if o.FetchAttestationsFromGitHubAPI() {
if o.APIClient == nil {
errMsg := "✗ No APIClient provided"
return nil, errMsg, errors.New(errMsg)
}
params := api.FetchParams{
Digest: a.DigestWithAlg(),
Limit: o.Limit,
Owner: o.Owner,
PredicateType: o.PredicateType,
Repo: o.Repo,
}
attestations, err := o.APIClient.GetByDigest(params)
if err != nil {
msg := "✗ Loading attestations from GitHub API failed"
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation)
return attestations, msg, nil
}
// Fetch attestations from local bundle or OCI registry
// Predicate type filtering is done after the attestations are fetched
var attestations []*api.Attestation
var err error
var msg string
if o.BundlePath != "" {
attestations, err := verification.GetLocalAttestations(o.BundlePath)
attestations, err = verification.GetLocalAttestations(o.BundlePath)
if err != nil {
msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL)
return nil, msg, err
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg = fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath)
} else {
msg = fmt.Sprintf("Loaded %d attestations from %s", len(attestations), o.BundlePath)
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath)
return attestations, msg, nil
}
if o.UseBundleFromRegistry {
attestations, err := verification.GetOCIAttestations(o.OCIClient, a)
} else if o.UseBundleFromRegistry {
attestations, err = verification.GetOCIAttestations(o.OCIClient, a)
if err != nil {
msg := "✗ Loading attestations from OCI registry failed"
return nil, msg, err
msg = "✗ Loading attestations from OCI registry failed"
} else {
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg = fmt.Sprintf("Loaded %s from OCI registry", pluralAttestation)
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath)
return attestations, msg, nil
}
params := verification.FetchRemoteAttestationsParams{
Digest: a.DigestWithAlg(),
Limit: o.Limit,
Owner: o.Owner,
Repo: o.Repo,
}
attestations, err := verification.GetRemoteAttestations(o.APIClient, params)
if err != nil {
msg := "✗ Loading attestations from GitHub API failed"
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation)
return attestations, msg, nil
filtered, err := api.FilterAttestations(o.PredicateType, attestations)
if err != nil {
return nil, err.Error(), err
}
return filtered, msg, nil
}
func verifyAttestations(art artifact.DigestedArtifact, att []*api.Attestation, sgVerifier verification.SigstoreVerifier, ec verification.EnforcementCriteria) ([]*verification.AttestationProcessingResult, string, error) {

View file

@ -0,0 +1,71 @@
package verify
import (
"testing"
"github.com/cli/cli/v2/pkg/cmd/attestation/api"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci"
"github.com/cli/cli/v2/pkg/cmd/attestation/verification"
"github.com/stretchr/testify/require"
)
func TestGetAttestations_OCIRegistry_PredicateTypeFiltering(t *testing.T) {
artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256")
require.NoError(t, err)
o := &Options{
OCIClient: oci.MockClient{},
PredicateType: verification.SLSAPredicateV1,
Repo: "cli/cli",
UseBundleFromRegistry: true,
}
attestations, msg, err := getAttestations(o, *artifact)
require.NoError(t, err)
require.Contains(t, msg, "Loaded 2 attestations from OCI registry")
require.Len(t, attestations, 2)
o.PredicateType = "custom predicate type"
attestations, msg, err = getAttestations(o, *artifact)
require.Error(t, err)
require.Contains(t, msg, "no attestations found with predicate type")
require.Nil(t, attestations)
}
func TestGetAttestations_LocalBundle_PredicateTypeFiltering(t *testing.T) {
artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256")
require.NoError(t, err)
o := &Options{
BundlePath: "../test/data/sigstore-js-2.1.0-bundle.json",
PredicateType: verification.SLSAPredicateV1,
Repo: "sigstore/sigstore-js",
}
attestations, _, err := getAttestations(o, *artifact)
require.NoError(t, err)
require.Len(t, attestations, 1)
o.PredicateType = "custom predicate type"
attestations, _, err = getAttestations(o, *artifact)
require.Error(t, err)
require.Nil(t, attestations)
}
func TestGetAttestations_GhAPI_NoAttestationsFound(t *testing.T) {
artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256")
require.NoError(t, err)
o := &Options{
APIClient: api.NewTestClient(),
PredicateType: verification.SLSAPredicateV1,
Repo: "sigstore/sigstore-js",
}
attestations, _, err := getAttestations(o, *artifact)
require.NoError(t, err)
require.Len(t, attestations, 2)
o.PredicateType = "custom predicate type"
attestations, _, err = getAttestations(o, *artifact)
require.Error(t, err)
require.Nil(t, attestations)
}

View file

@ -53,6 +53,12 @@ func (opts *Options) Clean() {
}
}
// FetchAttestationsFromGitHubAPI returns true if the command should fetch attestations from the GitHub API
// It checks that a bundle path is not provided and that the "use bundle from registry" flag is not set
func (opts *Options) FetchAttestationsFromGitHubAPI() bool {
return opts.BundlePath == "" && !opts.UseBundleFromRegistry
}
// AreFlagsValid checks that the provided flag combination is valid
// and returns an error otherwise
func (opts *Options) AreFlagsValid() error {

View file

@ -288,14 +288,6 @@ func runVerify(opts *Options) error {
// Print the message signifying success fetching attestations
opts.Logger.Println(logMsg)
// Apply predicate type filter to returned attestations
filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations)
if len(filteredAttestations) == 0 {
opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found with predicate type: %s\n"), opts.PredicateType)
return fmt.Errorf("no matching predicate found")
}
attestations = filteredAttestations
// print information about the policy that will be enforced against attestations
opts.Logger.Println("\nThe following policy criteria will be enforced:")
opts.Logger.Println(ec.BuildPolicyInformation())

View file

@ -510,7 +510,7 @@ func TestRunVerify(t *testing.T) {
err := runVerify(&customOpts)
require.Error(t, err)
require.ErrorContains(t, err, "no matching predicate found")
require.ErrorContains(t, err, "no attestations found with predicate type")
})
t.Run("with valid OCI artifact with UseBundleFromRegistry flag but no bundle return from registry", func(t *testing.T) {

View file

@ -14,3 +14,9 @@ if ! $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owne
echo "Failed to verify"
exit 1
fi
# Try to verify when specifying a predicate type that does not match the attestation
if $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owner=cli --predicate-type=my-custom-predicate-type; then
echo "Verification should have failed"
exit 1
fi