Merge pull request #9797 from cli/phillmv/retry-getting-attestations
`gh at verify` retries fetching attestations if it receives a 5xx
This commit is contained in:
commit
afa4272bdf
3 changed files with 150 additions and 10 deletions
|
|
@ -1,11 +1,14 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/cli/v2/api"
|
||||
ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io"
|
||||
)
|
||||
|
|
@ -69,6 +72,9 @@ func (c *LiveClient) GetTrustDomain() (string, error) {
|
|||
return c.getTrustDomain(MetaPath)
|
||||
}
|
||||
|
||||
// Allow injecting backoff interval in tests.
|
||||
var getAttestationRetryInterval = time.Millisecond * 200
|
||||
|
||||
func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) {
|
||||
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)
|
||||
|
||||
|
|
@ -86,15 +92,31 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
|
|||
|
||||
var attestations []*Attestation
|
||||
var resp AttestationsResponse
|
||||
var err error
|
||||
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
|
||||
|
||||
// if no attestation or less than limit, then keep fetching
|
||||
for url != "" && len(attestations) < limit {
|
||||
url, err = c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
|
||||
err := backoff.Retry(func() error {
|
||||
newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
|
||||
|
||||
if restErr != nil {
|
||||
if shouldRetry(restErr) {
|
||||
return restErr
|
||||
} else {
|
||||
return backoff.Permanent(restErr)
|
||||
}
|
||||
}
|
||||
|
||||
url = newURL
|
||||
attestations = append(attestations, resp.Attestations...)
|
||||
|
||||
return nil
|
||||
}, backoff.WithMaxRetries(bo, 3))
|
||||
|
||||
// bail if RESTWithNext errored out
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attestations = append(attestations, resp.Attestations...)
|
||||
}
|
||||
|
||||
if len(attestations) == 0 {
|
||||
|
|
@ -108,10 +130,34 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
|
|||
return attestations, nil
|
||||
}
|
||||
|
||||
func shouldRetry(err error) bool {
|
||||
var httpError api.HTTPError
|
||||
if errors.As(err, &httpError) {
|
||||
if httpError.StatusCode >= 500 && httpError.StatusCode <= 599 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *LiveClient) getTrustDomain(url string) (string, error) {
|
||||
var resp MetaResponse
|
||||
|
||||
err := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
|
||||
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
|
||||
err := backoff.Retry(func() error {
|
||||
restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
|
||||
if restErr != nil {
|
||||
if shouldRetry(restErr) {
|
||||
return restErr
|
||||
} else {
|
||||
return backoff.Permanent(restErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, backoff.WithMaxRetries(bo, 3))
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -204,3 +204,62 @@ func TestGetTrustDomain(t *testing.T) {
|
|||
})
|
||||
|
||||
}
|
||||
|
||||
func TestGetAttestationsRetries(t *testing.T) {
|
||||
getAttestationRetryInterval = 0
|
||||
|
||||
fetcher := mockDataGenerator{
|
||||
NumAttestations: 5,
|
||||
}
|
||||
|
||||
c := &LiveClient{
|
||||
api: mockAPIClient{
|
||||
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
|
||||
},
|
||||
logger: io.NewTestHandler(),
|
||||
}
|
||||
|
||||
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
|
||||
require.NoError(t, err)
|
||||
|
||||
// assert the error path was executed; because this is a paged
|
||||
// request, it should have errored twice
|
||||
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 2)
|
||||
|
||||
// but we still successfully got the right data
|
||||
require.Equal(t, len(attestations), 10)
|
||||
bundle := (attestations)[0].Bundle
|
||||
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
|
||||
|
||||
// same test as above, but for GetByOwnerAndDigest:
|
||||
attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
|
||||
require.NoError(t, err)
|
||||
|
||||
// because we haven't reset the mock, we have added 2 more failed requests
|
||||
fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4)
|
||||
|
||||
require.Equal(t, len(attestations), 10)
|
||||
bundle = (attestations)[0].Bundle
|
||||
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
|
||||
}
|
||||
|
||||
// test total retries
|
||||
func TestGetAttestationsMaxRetries(t *testing.T) {
|
||||
getAttestationRetryInterval = 0
|
||||
|
||||
fetcher := mockDataGenerator{
|
||||
NumAttestations: 5,
|
||||
}
|
||||
|
||||
c := &LiveClient{
|
||||
api: mockAPIClient{
|
||||
OnRESTWithNext: fetcher.OnREST500ErrorHandler(),
|
||||
},
|
||||
logger: io.NewTestHandler(),
|
||||
}
|
||||
|
||||
_, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
|
||||
require.Error(t, err)
|
||||
|
||||
fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
cliAPI "github.com/cli/cli/v2/api"
|
||||
ghAPI "github.com/cli/go-gh/v2/pkg/api"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockAPIClient struct {
|
||||
|
|
@ -22,14 +26,15 @@ func (m mockAPIClient) REST(hostname, method, p string, body io.Reader, data int
|
|||
}
|
||||
|
||||
type mockDataGenerator struct {
|
||||
mock.Mock
|
||||
NumAttestations int
|
||||
}
|
||||
|
||||
func (m mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
func (m *mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
|
||||
}
|
||||
|
||||
func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
func (m *mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
// if path doesn't contain after, it means first time hitting the mock server
|
||||
// so return the first page and return the link header in the response
|
||||
if !strings.Contains(p, "after") {
|
||||
|
|
@ -40,7 +45,37 @@ func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string,
|
|||
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
|
||||
}
|
||||
|
||||
func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
|
||||
// Returns a func that just calls OnRESTSuccessWithNextPage but half the time
|
||||
// it returns a 500 error.
|
||||
func (m *mockDataGenerator) FlakyOnRESTSuccessWithNextPageHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
// set up the flake counter
|
||||
m.On("FlakyOnRESTSuccessWithNextPage:error").Return()
|
||||
|
||||
count := 0
|
||||
return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
if count%2 == 0 {
|
||||
m.MethodCalled("FlakyOnRESTSuccessWithNextPage:error")
|
||||
|
||||
count = count + 1
|
||||
return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
|
||||
} else {
|
||||
count = count + 1
|
||||
return m.OnRESTSuccessWithNextPage(hostname, method, p, body, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// always returns a 500
|
||||
func (m *mockDataGenerator) OnREST500ErrorHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
m.On("OnREST500Error").Return()
|
||||
return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
m.MethodCalled("OnREST500Error")
|
||||
|
||||
return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
|
||||
atts := make([]*Attestation, m.NumAttestations)
|
||||
for j := 0; j < m.NumAttestations; j++ {
|
||||
att := makeTestAttestation()
|
||||
|
|
@ -70,7 +105,7 @@ func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p strin
|
|||
return "", nil
|
||||
}
|
||||
|
||||
func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
func (m *mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
resp := AttestationsResponse{
|
||||
Attestations: make([]*Attestation, 0),
|
||||
}
|
||||
|
|
@ -89,7 +124,7 @@ func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p stri
|
|||
return "", nil
|
||||
}
|
||||
|
||||
func (m mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
func (m *mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
|
||||
return "", errors.New("failed to get attestations")
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue