reorganize getAttestations func to check for remote gh api fetching first

Signed-off-by: Meredith Lancaster <malancas@github.com>
This commit is contained in:
Meredith Lancaster 2025-03-24 17:28:50 -06:00
parent a78c06970a
commit faef81f4bc
4 changed files with 35 additions and 27 deletions

View file

@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) {
t.Run("no attestations found", func(t *testing.T) {
opts := baseOpts
opts.APIClient = api.MockClient{
OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) {
OnGetByOwnerAndDigest: func(repo, digest, predicateType string, limit int) ([]*api.Attestation, error) {
return nil, api.ErrNoAttestationsFound
},
}
@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) {
t.Run("failed to fetch attestations", func(t *testing.T) {
opts := baseOpts
opts.APIClient = api.MockClient{
OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) {
OnGetByOwnerAndDigest: func(repo, digest, predicateType string, limit int) ([]*api.Attestation, error) {
return nil, fmt.Errorf("failed to fetch attestations")
},
}

View file

@ -21,10 +21,11 @@ var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not suppo
var ErrEmptyBundleFile = errors.New("provided bundle file is empty")
type FetchRemoteAttestationsParams struct {
Digest string
Limit int
Owner string
Repo string
Digest string
Limit int
Owner string
PredicateType string
Repo string
}
// GetLocalAttestations returns a slice of attestations read from a local bundle file.
@ -96,13 +97,13 @@ func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsPara
// check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo.
// If Repo is not set, the field will remain empty. It will not be populated using the value of Owner.
if params.Repo != "" {
attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit)
attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.PredicateType, params.Limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err)
}
return attestations, nil
} else if params.Owner != "" {
attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit)
attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.PredicateType, params.Limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err)
}

View file

@ -10,7 +10,24 @@ import (
)
func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) {
if o.BundlePath != "" {
if o.FetchAttestationsFromGitHubAPI() {
params := verification.FetchRemoteAttestationsParams{
Digest: a.DigestWithAlg(),
Limit: o.Limit,
Owner: o.Owner,
PredicateType: o.PredicateType,
Repo: o.Repo,
}
attestations, err := verification.GetRemoteAttestations(o.APIClient, params)
if err != nil {
msg := "✗ Loading attestations from GitHub API failed"
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation)
return attestations, msg, nil
} else if o.BundlePath != "" {
attestations, err := verification.GetLocalAttestations(o.BundlePath)
if err != nil {
msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL)
@ -19,9 +36,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath)
return attestations, msg, nil
}
if o.UseBundleFromRegistry {
} else if o.UseBundleFromRegistry {
attestations, err := verification.GetOCIAttestations(o.OCIClient, a)
if err != nil {
msg := "✗ Loading attestations from OCI registry failed"
@ -32,21 +47,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio
return attestations, msg, nil
}
params := verification.FetchRemoteAttestationsParams{
Digest: a.DigestWithAlg(),
Limit: o.Limit,
Owner: o.Owner,
Repo: o.Repo,
}
attestations, err := verification.GetRemoteAttestations(o.APIClient, params)
if err != nil {
msg := "✗ Loading attestations from GitHub API failed"
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation)
return attestations, msg, nil
return nil, "", fmt.Errorf("no valid attestation source provided")
}
func verifyAttestations(art artifact.DigestedArtifact, att []*api.Attestation, sgVerifier verification.SigstoreVerifier, ec verification.EnforcementCriteria) ([]*verification.AttestationProcessingResult, string, error) {

View file

@ -53,6 +53,12 @@ func (opts *Options) Clean() {
}
}
// FetchAttestationsFromGitHubAPI returns true if the command should fetch attestations from the GitHub API
// It checks that a bundle path is not provided and that the "use bundle from registry" flag is not set
func (opts *Options) FetchAttestationsFromGitHubAPI() bool {
return opts.BundlePath == "" && !opts.UseBundleFromRegistry
}
// AreFlagsValid checks that the provided flag combination is valid
// and returns an error otherwise
func (opts *Options) AreFlagsValid() error {