diff --git a/pkg/cmd/attestation/api/attestation.go b/pkg/cmd/attestation/api/attestation.go index fd6b484a7..daec12b50 100644 --- a/pkg/cmd/attestation/api/attestation.go +++ b/pkg/cmd/attestation/api/attestation.go @@ -1,7 +1,10 @@ package api import ( + "encoding/json" "errors" + "fmt" + "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -20,3 +23,35 @@ type Attestation struct { type AttestationsResponse struct { Attestations []*Attestation `json:"attestations"` } + +type IntotoStatement struct { + PredicateType string `json:"predicateType"` +} + +func FilterAttestations(predicateType string, attestations []*Attestation) ([]*Attestation, error) { + filteredAttestations := []*Attestation{} + + for _, each := range attestations { + dsseEnvelope := each.Bundle.GetDsseEnvelope() + if dsseEnvelope != nil { + if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { + // Don't fail just because an entry isn't intoto + continue + } + var intotoStatement IntotoStatement + if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { + // Don't fail just because a single entry can't be unmarshalled + continue + } + if intotoStatement.PredicateType == predicateType { + filteredAttestations = append(filteredAttestations, each) + } + } + } + + if len(filteredAttestations) == 0 { + return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) + } + + return filteredAttestations, nil +} diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index fbf44bbb9..b6062b39f 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -26,7 +26,12 @@ func (m MockClient) GetTrustDomain() (string, error) { func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil + attestations := []*Attestation{&att1, &att2} + if params.PredicateType != "" { + return FilterAttestations(params.PredicateType, attestations) + } + + return attestations, nil } func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) { diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 2cb648414..8d1d1dc05 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -9,7 +9,6 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/auth" "github.com/cli/cli/v2/pkg/cmd/attestation/io" - "github.com/cli/cli/v2/pkg/cmd/attestation/verification" "github.com/cli/cli/v2/pkg/cmdutil" ghauth "github.com/cli/go-gh/v2/pkg/auth" @@ -147,7 +146,7 @@ func runDownload(opts *Options) error { // Apply predicate type filter to returned attestations if opts.PredicateType != "" { - filteredAttestations, err := verification.FilterAttestations(opts.PredicateType, attestations) + filteredAttestations, err := api.FilterAttestations(opts.PredicateType, attestations) if err != nil { return fmt.Errorf("failed to filter attestations: %v", err) } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index ba357a5cc..10eb02ac4 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -92,35 +92,3 @@ func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ( } return attestations, nil } - -type IntotoStatement struct { - PredicateType string `json:"predicateType"` -} - -func FilterAttestations(predicateType string, attestations []*api.Attestation) ([]*api.Attestation, error) { - filteredAttestations := []*api.Attestation{} - - for _, each := range attestations { - dsseEnvelope := each.Bundle.GetDsseEnvelope() - if dsseEnvelope != nil { - if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { - // Don't fail just because an entry isn't intoto - continue - } - var intotoStatement IntotoStatement - if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { - // Don't fail just because a single entry can't be unmarshalled - continue - } - if intotoStatement.PredicateType == predicateType { - filteredAttestations = append(filteredAttestations, each) - } - } - } - - if len(filteredAttestations) == 0 { - return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) - } - - return filteredAttestations, nil -} diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index 18e2c6cca..6826e2e40 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -157,11 +157,11 @@ func TestFilterAttestations(t *testing.T) { }, } - filtered, err := FilterAttestations("https://slsa.dev/provenance/v1", attestations) + filtered, err := api.FilterAttestations("https://slsa.dev/provenance/v1", attestations) require.Len(t, filtered, 1) require.NoError(t, err) - filtered, err = FilterAttestations("NonExistentPredicate", attestations) + filtered, err = api.FilterAttestations("NonExistentPredicate", attestations) require.Nil(t, filtered) require.Error(t, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index 68d38f985..1b98fabf3 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -63,7 +63,7 @@ func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestatio return nil, msg, err } - filtered, err := verification.FilterAttestations(o.PredicateType, attestations) + filtered, err := api.FilterAttestations(o.PredicateType, attestations) if err != nil { return nil, err.Error(), err } diff --git a/pkg/cmd/attestation/verify/attestation_test.go b/pkg/cmd/attestation/verify/attestation_test.go index 93ac7d327..f015805ae 100644 --- a/pkg/cmd/attestation/verify/attestation_test.go +++ b/pkg/cmd/attestation/verify/attestation_test.go @@ -3,6 +3,7 @@ package verify import ( "testing" + "github.com/cli/cli/v2/pkg/cmd/attestation/api" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/verification" @@ -40,14 +41,31 @@ func TestGetAttestations_LocalBundle_PredicateTypeFiltering(t *testing.T) { PredicateType: verification.SLSAPredicateV1, Repo: "sigstore/sigstore-js", } - attestations, msg, err := getAttestations(o, *artifact) + attestations, _, err := getAttestations(o, *artifact) require.NoError(t, err) - require.Contains(t, msg, "Loaded 1 attestation from ../test/data/gh_2.60.1_windows_arm64.zip") require.Len(t, attestations, 1) o.PredicateType = "custom predicate type" - attestations, msg, err = getAttestations(o, *artifact) + attestations, _, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Nil(t, attestations) +} + +func TestGetAttestations_GhAPI_NoAttestationsFound(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + APIClient: api.NewTestClient(), + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, _, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, _, err = getAttestations(o, *artifact) require.Error(t, err) - require.Contains(t, msg, "no attestations found with predicate type") require.Nil(t, attestations) }