diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 460ae3aad..f47b5f759 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -1,11 +1,14 @@ package api import ( + "errors" "fmt" "io" "net/http" "strings" + "time" + "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/api" ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io" ) @@ -69,6 +72,9 @@ func (c *LiveClient) GetTrustDomain() (string, error) { return c.getTrustDomain(MetaPath) } +// Allow injecting backoff interval in tests. +var getAttestationRetryInterval = time.Millisecond * 200 + func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) { c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) @@ -86,15 +92,31 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At var attestations []*Attestation var resp AttestationsResponse - var err error + bo := backoff.NewConstantBackOff(getAttestationRetryInterval) + // if no attestation or less than limit, then keep fetching for url != "" && len(attestations) < limit { - url, err = c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) + err := backoff.Retry(func() error { + newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) + + if restErr != nil { + if shouldRetry(restErr) { + return restErr + } else { + return backoff.Permanent(restErr) + } + } + + url = newURL + attestations = append(attestations, resp.Attestations...) + + return nil + }, backoff.WithMaxRetries(bo, 3)) + + // bail if RESTWithNext errored out if err != nil { return nil, err } - - attestations = append(attestations, resp.Attestations...) } if len(attestations) == 0 { @@ -108,10 +130,34 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At return attestations, nil } +func shouldRetry(err error) bool { + var httpError api.HTTPError + if errors.As(err, &httpError) { + if httpError.StatusCode >= 500 && httpError.StatusCode <= 599 { + return true + } + } + + return false +} + func (c *LiveClient) getTrustDomain(url string) (string, error) { var resp MetaResponse - err := c.api.REST(c.host, http.MethodGet, url, nil, &resp) + bo := backoff.NewConstantBackOff(getAttestationRetryInterval) + err := backoff.Retry(func() error { + restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp) + if restErr != nil { + if shouldRetry(restErr) { + return restErr + } else { + return backoff.Permanent(restErr) + } + } + + return nil + }, backoff.WithMaxRetries(bo, 3)) + if err != nil { return "", err } diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index bfcb40f5a..adac00598 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -204,3 +204,62 @@ func TestGetTrustDomain(t *testing.T) { }) } + +func TestGetAttestationsRetries(t *testing.T) { + getAttestationRetryInterval = 0 + + fetcher := mockDataGenerator{ + NumAttestations: 5, + } + + c := &LiveClient{ + api: mockAPIClient{ + OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(), + }, + logger: io.NewTestHandler(), + } + + attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + require.NoError(t, err) + + // assert the error path was executed; because this is a paged + // request, it should have errored twice + fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 2) + + // but we still successfully got the right data + require.Equal(t, len(attestations), 10) + bundle := (attestations)[0].Bundle + require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + + // same test as above, but for GetByOwnerAndDigest: + attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + require.NoError(t, err) + + // because we haven't reset the mock, we have added 2 more failed requests + fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4) + + require.Equal(t, len(attestations), 10) + bundle = (attestations)[0].Bundle + require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") +} + +// test total retries +func TestGetAttestationsMaxRetries(t *testing.T) { + getAttestationRetryInterval = 0 + + fetcher := mockDataGenerator{ + NumAttestations: 5, + } + + c := &LiveClient{ + api: mockAPIClient{ + OnRESTWithNext: fetcher.OnREST500ErrorHandler(), + }, + logger: io.NewTestHandler(), + } + + _, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + require.Error(t, err) + + fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) +} diff --git a/pkg/cmd/attestation/api/mock_apiClient_test.go b/pkg/cmd/attestation/api/mock_apiClient_test.go index d58cdbc79..e1654bd3f 100644 --- a/pkg/cmd/attestation/api/mock_apiClient_test.go +++ b/pkg/cmd/attestation/api/mock_apiClient_test.go @@ -6,6 +6,10 @@ import ( "fmt" "io" "strings" + + cliAPI "github.com/cli/cli/v2/api" + ghAPI "github.com/cli/go-gh/v2/pkg/api" + "github.com/stretchr/testify/mock" ) type mockAPIClient struct { @@ -22,14 +26,15 @@ func (m mockAPIClient) REST(hostname, method, p string, body io.Reader, data int } type mockDataGenerator struct { + mock.Mock NumAttestations int } -func (m mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) { +func (m *mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) { return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false) } -func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) { +func (m *mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) { // if path doesn't contain after, it means first time hitting the mock server // so return the first page and return the link header in the response if !strings.Contains(p, "after") { @@ -40,7 +45,37 @@ func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false) } -func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) { +// Returns a func that just calls OnRESTSuccessWithNextPage but half the time +// it returns a 500 error. +func (m *mockDataGenerator) FlakyOnRESTSuccessWithNextPageHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) { + // set up the flake counter + m.On("FlakyOnRESTSuccessWithNextPage:error").Return() + + count := 0 + return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) { + if count%2 == 0 { + m.MethodCalled("FlakyOnRESTSuccessWithNextPage:error") + + count = count + 1 + return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}} + } else { + count = count + 1 + return m.OnRESTSuccessWithNextPage(hostname, method, p, body, data) + } + } +} + +// always returns a 500 +func (m *mockDataGenerator) OnREST500ErrorHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) { + m.On("OnREST500Error").Return() + return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) { + m.MethodCalled("OnREST500Error") + + return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}} + } +} + +func (m *mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) { atts := make([]*Attestation, m.NumAttestations) for j := 0; j < m.NumAttestations; j++ { att := makeTestAttestation() @@ -70,7 +105,7 @@ func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p strin return "", nil } -func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) { +func (m *mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) { resp := AttestationsResponse{ Attestations: make([]*Attestation, 0), } @@ -89,7 +124,7 @@ func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p stri return "", nil } -func (m mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) { +func (m *mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) { return "", errors.New("failed to get attestations") }