diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 55fcf1f7e..5774170d9 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -9,6 +9,16 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/verification" ) +func filterByPredicateType(predicateType string, attestations []*api.Attestation) ([]*api.Attestation, string, error) { + // Apply predicate type filter to returned attestations + filteredAttestations := verification.FilterAttestations(predicateType, attestations) + if len(filteredAttestations) == 0 { + msg := fmt.Sprintf("✗ No attestations found with predicate type: %s\n", predicateType) + return nil, msg, fmt.Errorf("no matching predicate found") + } + return filteredAttestations, "", nil +} + func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { if o.FetchAttestationsFromGitHubAPI() { params := verification.FetchRemoteAttestationsParams{ @@ -33,18 +43,30 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) return nil, msg, err } - pluralAttestation := text.Pluralize(len(attestations), "attestation") + + filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) + if err != nil { + return nil, errMsg, err + } + + pluralAttestation := text.Pluralize(len(filtered), "attestation") msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) - return attestations, msg, nil + return filtered, msg, nil } else if o.UseBundleFromRegistry { attestations, err := verification.GetOCIAttestations(o.OCIClient, a) if err != nil { msg := "✗ Loading attestations from OCI registry failed" return nil, msg, err } - pluralAttestation := text.Pluralize(len(attestations), "attestation") + + filtered, errMsg, err := filterByPredicateType(o.PredicateType, attestations) + if err != nil { + return nil, errMsg, err + } + + pluralAttestation := text.Pluralize(len(filtered), "attestation") msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) - return attestations, msg, nil + return filtered, msg, nil } return nil, "", fmt.Errorf("no valid attestation source provided") diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index 65ae8ca3e..1de4172b4 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -235,14 +235,6 @@ func runVerify(opts *Options) error { // Print the message signifying success fetching attestations opts.Logger.Println(logMsg) - // Apply predicate type filter to returned attestations - filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations) - if len(filteredAttestations) == 0 { - opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found with predicate type: %s\n"), opts.PredicateType) - return fmt.Errorf("no matching predicate found") - } - attestations = filteredAttestations - // print information about the policy that will be enforced against attestations opts.Logger.Println("\nThe following policy criteria will be enforced:") opts.Logger.Println(ec.BuildPolicyInformation())