Merge pull request #9797 from cli/phillmv/retry-getting-attestations

`gh at verify` retries fetching attestations if it receives a 5xx
This commit is contained in:
Phill MV 2024-10-23 13:45:09 -04:00 committed by GitHub
commit afa4272bdf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 150 additions and 10 deletions

View file

@ -1,11 +1,14 @@
package api
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/api"
ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io"
)
@ -69,6 +72,9 @@ func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}
// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200
func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)
@ -86,15 +92,31 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
var attestations []*Attestation
var resp AttestationsResponse
var err error
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
url, err = c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
err := backoff.Retry(func() error {
newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
}
url = newURL
attestations = append(attestations, resp.Attestations...)
return nil
}, backoff.WithMaxRetries(bo, 3))
// bail if RESTWithNext errored out
if err != nil {
return nil, err
}
attestations = append(attestations, resp.Attestations...)
}
if len(attestations) == 0 {
@ -108,10 +130,34 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
return attestations, nil
}
func shouldRetry(err error) bool {
var httpError api.HTTPError
if errors.As(err, &httpError) {
if httpError.StatusCode >= 500 && httpError.StatusCode <= 599 {
return true
}
}
return false
}
func (c *LiveClient) getTrustDomain(url string) (string, error) {
var resp MetaResponse
err := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
err := backoff.Retry(func() error {
restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
}
return nil
}, backoff.WithMaxRetries(bo, 3))
if err != nil {
return "", err
}

View file

@ -204,3 +204,62 @@ func TestGetTrustDomain(t *testing.T) {
})
}
func TestGetAttestationsRetries(t *testing.T) {
getAttestationRetryInterval = 0
fetcher := mockDataGenerator{
NumAttestations: 5,
}
c := &LiveClient{
api: mockAPIClient{
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
},
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.NoError(t, err)
// assert the error path was executed; because this is a paged
// request, it should have errored twice
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 2)
// but we still successfully got the right data
require.Equal(t, len(attestations), 10)
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
// same test as above, but for GetByOwnerAndDigest:
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
require.NoError(t, err)
// because we haven't reset the mock, we have added 2 more failed requests
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4)
require.Equal(t, len(attestations), 10)
bundle = (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
}
// test total retries
func TestGetAttestationsMaxRetries(t *testing.T) {
getAttestationRetryInterval = 0
fetcher := mockDataGenerator{
NumAttestations: 5,
}
c := &LiveClient{
api: mockAPIClient{
OnRESTWithNext: fetcher.OnREST500ErrorHandler(),
},
logger: io.NewTestHandler(),
}
_, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
require.Error(t, err)
fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4)
}

View file

@ -6,6 +6,10 @@ import (
"fmt"
"io"
"strings"
cliAPI "github.com/cli/cli/v2/api"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
"github.com/stretchr/testify/mock"
)
type mockAPIClient struct {
@ -22,14 +26,15 @@ func (m mockAPIClient) REST(hostname, method, p string, body io.Reader, data int
}
type mockDataGenerator struct {
mock.Mock
NumAttestations int
}
func (m mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
}
func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// if path doesn't contain after, it means first time hitting the mock server
// so return the first page and return the link header in the response
if !strings.Contains(p, "after") {
@ -40,7 +45,37 @@ func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string,
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
}
func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
// Returns a func that just calls OnRESTSuccessWithNextPage but half the time
// it returns a 500 error.
func (m *mockDataGenerator) FlakyOnRESTSuccessWithNextPageHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// set up the flake counter
m.On("FlakyOnRESTSuccessWithNextPage:error").Return()
count := 0
return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
if count%2 == 0 {
m.MethodCalled("FlakyOnRESTSuccessWithNextPage:error")
count = count + 1
return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
} else {
count = count + 1
return m.OnRESTSuccessWithNextPage(hostname, method, p, body, data)
}
}
}
// always returns a 500
func (m *mockDataGenerator) OnREST500ErrorHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
m.On("OnREST500Error").Return()
return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
m.MethodCalled("OnREST500Error")
return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
}
}
func (m *mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
atts := make([]*Attestation, m.NumAttestations)
for j := 0; j < m.NumAttestations; j++ {
att := makeTestAttestation()
@ -70,7 +105,7 @@ func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p strin
return "", nil
}
func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
resp := AttestationsResponse{
Attestations: make([]*Attestation, 0),
}
@ -89,7 +124,7 @@ func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p stri
return "", nil
}
func (m mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
return "", errors.New("failed to get attestations")
}