Added constant backoff retry to getAttestations.

This commit is contained in:
Phill MV 2024-10-21 12:10:18 -04:00
parent 664e09fdbc
commit efc1c97cf1
2 changed files with 42 additions and 3 deletions

View file

@ -1,11 +1,14 @@
package api
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/api"
ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io"
)
@ -69,6 +72,9 @@ func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}
// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200
func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)
@ -87,14 +93,31 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
var attestations []*Attestation
var resp AttestationsResponse
var err error
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
url, err = c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
err = backoff.Retry(func() error {
newURL, err := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
if err != nil {
if shouldRetry(err) {
return err
} else {
return backoff.Permanent(err)
}
}
url = newURL
attestations = append(attestations, resp.Attestations...)
return nil
}, backoff.WithMaxRetries(bo, 3))
// bail if RESTWithNext errored out
if err != nil {
return nil, err
}
attestations = append(attestations, resp.Attestations...)
}
if len(attestations) == 0 {
@ -108,6 +131,17 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
return attestations, nil
}
func shouldRetry(err error) bool {
var httpError api.HTTPError
if errors.As(err, &httpError) {
if httpError.StatusCode >= 500 && httpError.StatusCode <= 599 {
return true
}
}
return false
}
func (c *LiveClient) getTrustDomain(url string) (string, error) {
var resp MetaResponse

View file

@ -206,6 +206,9 @@ func TestGetTrustDomain(t *testing.T) {
}
func TestGetAttestationsRetries(t *testing.T) {
oldInterval := getAttestationRetryInterval
getAttestationRetryInterval = 0
fetcher := mockDataGenerator{
NumAttestations: 5,
}
@ -229,4 +232,6 @@ func TestGetAttestationsRetries(t *testing.T) {
require.Equal(t, 10, len(attestations))
bundle := (attestations)[0].Bundle
require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json")
getAttestationRetryInterval = oldInterval
}