diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index 7c93817be..b131b7acd 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -115,7 +115,7 @@ func GetRemoteAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error return nil, fmt.Errorf("owner or repo must be provided") } -type DssePayload struct { +type IntotoStatement struct { PredicateType string `json:"predicateType"` } @@ -125,12 +125,16 @@ func FilterAttestations(predicateType string, attestations []*api.Attestation) [ for _, each := range attestations { dsseEnvelope := each.Bundle.GetDsseEnvelope() if dsseEnvelope != nil { - var dssePayload DssePayload - if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &dssePayload); err != nil { + if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { + // Don't fail just because an entry isn't intoto + continue + } + var intotoStatement IntotoStatement + if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { // Don't fail just because a single entry can't be unmarshalled continue } - if dssePayload.PredicateType == predicateType { + if intotoStatement.PredicateType == predicateType { filteredAttestations = append(filteredAttestations, each) } } diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index b178a0682..ef7c7d879 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -3,7 +3,12 @@ package verification import ( "testing" + protobundle "github.com/sigstore/protobuf-specs/gen/pb-go/bundle/v1" + dsse "github.com/sigstore/protobuf-specs/gen/pb-go/dsse" + "github.com/sigstore/sigstore-go/pkg/bundle" "github.com/stretchr/testify/require" + + "github.com/cli/cli/v2/pkg/cmd/attestation/api" ) func TestLoadBundlesFromJSONLinesFile(t *testing.T) { @@ -47,3 +52,51 @@ func TestGetLocalAttestations(t *testing.T) { require.Nil(t, attestations) }) } + +func TestFilterAttestations(t *testing.T) { + attestations := []*api.Attestation{ + { + Bundle: &bundle.ProtobufBundle{ + Bundle: &protobundle.Bundle{ + Content: &protobundle.Bundle_DsseEnvelope{ + DsseEnvelope: &dsse.Envelope{ + PayloadType: "application/vnd.in-toto+json", + Payload: []byte("{\"predicateType\": \"https://slsa.dev/provenance/v1\"}"), + }, + }, + }, + }, + }, + { + Bundle: &bundle.ProtobufBundle{ + Bundle: &protobundle.Bundle{ + Content: &protobundle.Bundle_DsseEnvelope{ + DsseEnvelope: &dsse.Envelope{ + PayloadType: "application/vnd.something-other-than-in-toto+json", + Payload: []byte("{\"predicateType\": \"https://slsa.dev/provenance/v1\"}"), + }, + }, + }, + }, + }, + { + Bundle: &bundle.ProtobufBundle{ + Bundle: &protobundle.Bundle{ + Content: &protobundle.Bundle_DsseEnvelope{ + DsseEnvelope: &dsse.Envelope{ + PayloadType: "application/vnd.in-toto+json", + Payload: []byte("{\"predicateType\": \"https://spdx.dev/Document/v2.3\"}"), + }, + }, + }, + }, + }, + } + + filtered := FilterAttestations("https://slsa.dev/provenance/v1", attestations) + + require.Len(t, filtered, 1) + + filtered = FilterAttestations("NonExistantPredicate", attestations) + require.Len(t, filtered, 0) +}