diff --git a/pkg/cmd/attestation/api/attestation.go b/pkg/cmd/attestation/api/attestation.go index 16b062a96..fd6b484a7 100644 --- a/pkg/cmd/attestation/api/attestation.go +++ b/pkg/cmd/attestation/api/attestation.go @@ -1,8 +1,7 @@ package api import ( - "fmt" - + "errors" "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -11,18 +10,7 @@ const ( GetAttestationByOwnerAndSubjectDigestPath = "orgs/%s/attestations/%s" ) -type ErrNoAttestations struct { - name string - digest string -} - -func (e ErrNoAttestations) Error() string { - return fmt.Sprintf("no attestations found for digest %s in %s", e.name, e.digest) -} - -func newErrNoAttestations(name, digest string) ErrNoAttestations { - return ErrNoAttestations{name, digest} -} +var ErrNoAttestationsFound = errors.New("no attestations found") type Attestation struct { Bundle *bundle.Bundle `json:"bundle"` diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 6054bc98e..1e99a2a06 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -60,44 +60,26 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien } } -func (c *LiveClient) BuildRepoAndDigestURL(repo, digest string) string { - repo = strings.Trim(repo, "/") - return fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) -} - // GetByRepoAndDigest fetches the attestation by repo and digest func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - url := c.BuildRepoAndDigestURL(repo, digest) - attestations, err := c.getAttestations(url, repo, digest, limit) - if err != nil { - return nil, err - } - - bundles, err := c.fetchBundleFromAttestations(attestations) - if err != nil { - return nil, fmt.Errorf("failed to fetch bundle with URL: %w", err) - } - - return bundles, nil -} - -func (c *LiveClient) BuildOwnerAndDigestURL(owner, digest string) string { - owner = strings.Trim(owner, "/") - return fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest) + c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) + url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) + return c.getByURL(url, limit) } // GetByOwnerAndDigest fetches attestation by owner and digest func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { - url := c.BuildOwnerAndDigestURL(owner, digest) - attestations, err := c.getAttestations(url, owner, digest, limit) + c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) + url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest) + return c.getByURL(url, limit) +} + +func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) { + attestations, err := c.getAttestations(url, limit) if err != nil { return nil, err } - if len(attestations) == 0 { - return nil, newErrNoAttestations(owner, digest) - } - bundles, err := c.fetchBundleFromAttestations(attestations) if err != nil { return nil, fmt.Errorf("failed to fetch bundle with URL: %w", err) @@ -112,9 +94,7 @@ func (c *LiveClient) GetTrustDomain() (string, error) { return c.getTrustDomain(MetaPath) } -func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - +func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, error) { perPage := limit if perPage <= 0 || perPage > maxLimitForFlag { return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) @@ -157,7 +137,7 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At } if len(attestations) == 0 { - return nil, newErrNoAttestations(name, digest) + return nil, ErrNoAttestationsFound } if len(attestations) > limit { @@ -176,23 +156,22 @@ func (c *LiveClient) fetchBundleFromAttestations(attestations []*Attestation) ([ return fmt.Errorf("attestation has no bundle or bundle URL") } - // If the bundle field is nil, try to fetch the bundle with the provided URL - if a.Bundle == nil { - c.logger.VerbosePrintf("Bundle field is empty. Trying to fetch with bundle URL\n\n") - b, err := c.GetBundle(a.BundleURL) - if err != nil { - return fmt.Errorf("failed to fetch bundle with URL: %w", err) - } + // for now, we fall back to the bundle field if the bundle URL is empty + if a.BundleURL == "" { + c.logger.VerbosePrintf("Bundle URL is empty. Falling back to bundle field\n\n") fetched[i] = &Attestation{ - Bundle: b, + Bundle: a.Bundle, } return nil } - // otherwise fall back to the bundle field - c.logger.VerbosePrintf("Fetching bundle from Bundle field\n\n") + // otherwise fetch the bundle with the provided URL + b, err := c.getBundle(a.BundleURL) + if err != nil { + return fmt.Errorf("failed to fetch bundle with URL: %w", err) + } fetched[i] = &Attestation{ - Bundle: a.Bundle, + Bundle: b, } return nil @@ -206,38 +185,49 @@ func (c *LiveClient) fetchBundleFromAttestations(attestations []*Attestation) ([ return fetched, nil } -func (c *LiveClient) GetBundle(url string) (*bundle.Bundle, error) { +func (c *LiveClient) getBundle(url string) (*bundle.Bundle, error) { c.logger.VerbosePrintf("Fetching attestation bundle with bundle URL\n\n") - resp, err := c.httpClient.Get(url) - if err != nil { - return nil, err - } + var sgBundle *bundle.Bundle + bo := backoff.NewConstantBackOff(getAttestationRetryInterval) + err := backoff.Retry(func() error { + resp, err := c.httpClient.Get(url) + if err != nil { + return fmt.Errorf("request to fetch bundle from URL failed: %w", err) + } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } + if resp.StatusCode >= 500 && resp.StatusCode <= 599 { + return fmt.Errorf("attestation bundle with URL %s returned status code %d", url, resp.StatusCode) + } - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read blob storage response body: %w", err) - } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read blob storage response body: %w", err) + } - var out []byte - decompressed, err := snappy.Decode(out, body) - if err != nil { - return nil, fmt.Errorf("failed to decompress with snappy: %w", err) - } + var out []byte + decompressed, err := snappy.Decode(out, body) + if err != nil { + return backoff.Permanent(fmt.Errorf("failed to decompress with snappy: %w", err)) + } - var pbBundle v1.Bundle - if err = protojson.Unmarshal(decompressed, &pbBundle); err != nil { - return nil, fmt.Errorf("failed to unmarshal to bundle: %w", err) - } + var pbBundle v1.Bundle + if err = protojson.Unmarshal(decompressed, &pbBundle); err != nil { + return backoff.Permanent(fmt.Errorf("failed to unmarshal to bundle: %w", err)) + } - c.logger.VerbosePrintf("Successfully fetched bundle\n\n") + c.logger.VerbosePrintf("Successfully fetched bundle\n\n") - return bundle.NewBundle(&pbBundle) + sgBundle, err = bundle.NewBundle(&pbBundle) + if err != nil { + return backoff.Permanent(fmt.Errorf("failed to create new bundle: %w", err)) + } + + return nil + }, backoff.WithMaxRetries(bo, 3)) + + return sgBundle, err } func shouldRetry(err error) bool { diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 3d180af8f..787408a4e 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -42,24 +42,6 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } } -func TestGetURL(t *testing.T) { - c := LiveClient{} - - testData := []struct { - repo string - digest string - expected string - }{ - {repo: "/github/example/", digest: "sha256:12313213", expected: "repos/github/example/attestations/sha256:12313213"}, - {repo: "/github/example", digest: "sha256:12313213", expected: "repos/github/example/attestations/sha256:12313213"}, - } - - for _, data := range testData { - s := c.BuildRepoAndDigestURL(data.repo, data.digest) - require.Equal(t, data.expected, s) - } -} - func TestGetByDigest(t *testing.T) { c := NewClientWithMockGHClient(false) attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) @@ -150,12 +132,12 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) require.Error(t, err) - require.IsType(t, ErrNoAttestations{}, err) + require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) require.Error(t, err) - require.IsType(t, ErrNoAttestations{}, err) + require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) } @@ -188,10 +170,7 @@ func TestFetchBundleFromAttestations_BundleURL(t *testing.T) { } att1 := makeTestAttestation() - att1.Bundle = nil att2 := makeTestAttestation() - att2.Bundle = nil - // zero out the bundle field so it tries fetching by URL attestations := []*Attestation{&att1, &att2} fetched, err := client.fetchBundleFromAttestations(attestations) require.NoError(t, err) @@ -200,42 +179,46 @@ func TestFetchBundleFromAttestations_BundleURL(t *testing.T) { httpClient.AssertNumberOfCalls(t, "OnGetSuccess", 2) } -func TestFetchBundleFromAttestations_InvalidAttestation(t *testing.T) { +func TestFetchBundleFromAttestations_MissingBundleAndBundleURLFields(t *testing.T) { httpClient := &mockHttpClient{} client := LiveClient{ httpClient: httpClient, logger: io.NewTestHandler(), } + // If both the BundleURL and Bundle fields are empty, the function should + // return an error indicating that att1 := Attestation{} attestations := []*Attestation{&att1} - fetched, err := client.fetchBundleFromAttestations(attestations) - require.Error(t, err) - require.Nil(t, fetched, 2) + bundles, err := client.fetchBundleFromAttestations(attestations) + require.ErrorContains(t, err, "attestation has no bundle or bundle URL") + require.Nil(t, bundles, 2) } -func TestFetchBundleFromAttestations_Fail_BundleURL(t *testing.T) { - httpClient := &failAfterOneCallHttpClient{} +func TestFetchBundleFromAttestations_FailOnTheSecondAttestation(t *testing.T) { + mockHTTPClient := &failAfterNCallsHttpClient{ + // the initial HTTP request will succeed, which returns a bundle for the first attestation + // all following HTTP requests will fail, which means the function fails to fetch a bundle + // for the second attestation and the function returns an error + FailOnCallN: 2, + FailOnAllSubsequentCalls: true, + } c := &LiveClient{ - httpClient: httpClient, + httpClient: mockHTTPClient, logger: io.NewTestHandler(), } att1 := makeTestAttestation() - att1.Bundle = nil att2 := makeTestAttestation() - att2.Bundle = nil - // zero out the bundle field so it tries fetching by URL attestations := []*Attestation{&att1, &att2} - fetched, err := c.fetchBundleFromAttestations(attestations) + bundles, err := c.fetchBundleFromAttestations(attestations) require.Error(t, err) - require.Nil(t, fetched) - httpClient.AssertNumberOfCalls(t, "OnGetFailAfterOneCall", 2) + require.Nil(t, bundles) } -func TestFetchBundleFromAttestations_FetchByURLFail(t *testing.T) { - mockHTTPClient := &failHttpClient{} +func TestFetchBundleFromAttestations_FailAfterRetrying(t *testing.T) { + mockHTTPClient := &reqFailHttpClient{} c := &LiveClient{ httpClient: mockHTTPClient, @@ -243,15 +226,14 @@ func TestFetchBundleFromAttestations_FetchByURLFail(t *testing.T) { } a := makeTestAttestation() - a.Bundle = nil attestations := []*Attestation{&a} bundle, err := c.fetchBundleFromAttestations(attestations) require.Error(t, err) require.Nil(t, bundle) - mockHTTPClient.AssertNumberOfCalls(t, "OnGetFail", 1) + mockHTTPClient.AssertNumberOfCalls(t, "OnGetReqFail", 4) } -func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { +func TestFetchBundleFromAttestations_FallbackToBundleField(t *testing.T) { mockHTTPClient := &mockHttpClient{} c := &LiveClient{ @@ -259,6 +241,7 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { logger: io.NewTestHandler(), } + // If the bundle URL is empty, the code will fallback to the bundle field a := Attestation{Bundle: data.SigstoreBundle(t)} attestations := []*Attestation{&a} fetched, err := c.fetchBundleFromAttestations(attestations) @@ -267,6 +250,71 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { mockHTTPClient.AssertNotCalled(t, "OnGetSuccess") } +// getBundle successfully fetches a bundle on the first HTTP request attempt +func TestGetBundle(t *testing.T) { + mockHTTPClient := &mockHttpClient{} + + c := &LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + b, err := c.getBundle("https://mybundleurl.com") + require.NoError(t, err) + require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", b.GetMediaType()) + mockHTTPClient.AssertNumberOfCalls(t, "OnGetSuccess", 1) +} + +// getBundle retries successfully when the initial HTTP request returns +// a 5XX status code +func TestGetBundle_SuccessfulRetry(t *testing.T) { + mockHTTPClient := &failAfterNCallsHttpClient{ + FailOnCallN: 1, + FailOnAllSubsequentCalls: false, + } + + c := &LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + b, err := c.getBundle("mybundleurl") + require.NoError(t, err) + require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", b.GetMediaType()) + mockHTTPClient.AssertNumberOfCalls(t, "OnGetFailAfterNCalls", 2) +} + +// getBundle does not retry when the function fails with a permanent backoff error condition +func TestGetBundle_PermanentBackoffFail(t *testing.T) { + mockHTTPClient := &invalidBundleClient{} + c := &LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + b, err := c.getBundle("mybundleurl") + // var permanent *backoff.PermanentError + //require.IsType(t, &backoff.PermanentError{}, err) + require.Error(t, err) + require.Nil(t, b) + mockHTTPClient.AssertNumberOfCalls(t, "OnGetInvalidBundle", 1) +} + +// getBundle retries when the HTTP request fails +func TestGetBundle_RequestFail(t *testing.T) { + mockHTTPClient := &reqFailHttpClient{} + + c := &LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + b, err := c.getBundle("mybundleurl") + require.Error(t, err) + require.Nil(t, b) + mockHTTPClient.AssertNumberOfCalls(t, "OnGetReqFail", 4) +} + func TestGetTrustDomain(t *testing.T) { fetcher := mockMetaGenerator{ TrustDomain: "foo", diff --git a/pkg/cmd/attestation/api/mock_httpClient_test.go b/pkg/cmd/attestation/api/mock_httpClient_test.go index 4dc34fbf8..26933ae2e 100644 --- a/pkg/cmd/attestation/api/mock_httpClient_test.go +++ b/pkg/cmd/attestation/api/mock_httpClient_test.go @@ -27,34 +27,55 @@ func (m *mockHttpClient) Get(url string) (*http.Response, error) { }, nil } -type failHttpClient struct { +type invalidBundleClient struct { mock.Mock } -func (m *failHttpClient) Get(url string) (*http.Response, error) { - m.On("OnGetFail").Return() - m.MethodCalled("OnGetFail") +func (m *invalidBundleClient) Get(url string) (*http.Response, error) { + m.On("OnGetInvalidBundle").Return() + m.MethodCalled("OnGetInvalidBundle") + + var compressed []byte + compressed = snappy.Encode(compressed, []byte("invalid bundle bytes")) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(compressed)), + }, nil +} + +type reqFailHttpClient struct { + mock.Mock +} + +func (m *reqFailHttpClient) Get(url string) (*http.Response, error) { + m.On("OnGetReqFail").Return() + m.MethodCalled("OnGetReqFail") return &http.Response{ StatusCode: 500, }, fmt.Errorf("failed to fetch with %s", url) } -type failAfterOneCallHttpClient struct { +type failAfterNCallsHttpClient struct { mock.Mock + FailOnCallN int + FailOnAllSubsequentCalls bool + NumCalls int } -func (m *failAfterOneCallHttpClient) Get(url string) (*http.Response, error) { - m.On("OnGetFailAfterOneCall").Return() +func (m *failAfterNCallsHttpClient) Get(url string) (*http.Response, error) { + m.On("OnGetFailAfterNCalls").Return() - if len(m.Calls) >= 1 { - m.MethodCalled("OnGetFailAfterOneCall") + m.NumCalls++ + + if m.NumCalls == m.FailOnCallN || (m.NumCalls > m.FailOnCallN && m.FailOnAllSubsequentCalls) { + m.MethodCalled("OnGetFailAfterNCalls") return &http.Response{ StatusCode: 500, - }, fmt.Errorf("failed to fetch with %s", url) + }, nil } - m.MethodCalled("OnGetFailAfterOneCall") + m.MethodCalled("OnGetFailAfterNCalls") var compressed []byte compressed = snappy.Encode(compressed, data.SigstoreBundleRaw) return &http.Response{ diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 7547e9e68..cdbdc0078 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -135,7 +135,7 @@ func runDownload(opts *Options) error { } attestations, err := verification.GetRemoteAttestations(opts.APIClient, params) if err != nil { - if errors.Is(err, api.ErrNoAttestations{}) { + if errors.Is(err, api.ErrNoAttestationsFound) { fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath) return nil } diff --git a/pkg/cmd/attestation/download/download_test.go b/pkg/cmd/attestation/download/download_test.go index 6c2986065..ddcd08c92 100644 --- a/pkg/cmd/attestation/download/download_test.go +++ b/pkg/cmd/attestation/download/download_test.go @@ -276,7 +276,7 @@ func TestRunDownload(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { - return nil, api.ErrNoAttestations{} + return nil, api.ErrNoAttestationsFound }, } diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index ea7502f00..90242a9fe 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -221,7 +221,7 @@ func runVerify(opts *Options) error { attestations, logMsg, err := getAttestations(opts, *artifact) if err != nil { - if ok := errors.Is(err, api.ErrNoAttestations{}); ok { + if ok := errors.Is(err, api.ErrNoAttestationsFound); ok { opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found for subject %s\n"), artifact.DigestWithAlg()) return err }