add httpClient field to LiveClient struct

Signed-off-by: Meredith Lancaster <malancas@github.com>
This commit is contained in:
Meredith Lancaster 2024-12-16 11:57:42 -07:00
parent e51b4efaa9
commit 6b95175363
2 changed files with 28 additions and 24 deletions

View file

@ -26,11 +26,15 @@ const (
// Allow injecting backoff interval in tests.
var getAttestationRetryInterval = time.Millisecond * 200
type apiClient interface {
type githubApiClient interface {
REST(hostname, method, p string, body io.Reader, data interface{}) error
RESTWithNext(hostname, method, p string, body io.Reader, data interface{}) (string, error)
}
type httpClient interface {
Get(url string) (*http.Response, error)
}
type Client interface {
GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error)
GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error)
@ -38,27 +42,29 @@ type Client interface {
}
type LiveClient struct {
api apiClient
host string
logger *ioconfig.Handler
githubAPI githubApiClient
httpClient httpClient
host string
logger *ioconfig.Handler
}
func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClient {
return &LiveClient{
api: api.NewClientFromHTTP(hc),
host: strings.TrimSuffix(host, "/"),
logger: l,
githubAPI: api.NewClientFromHTTP(hc),
host: strings.TrimSuffix(host, "/"),
httpClient: hc,
logger: l,
}
}
func (c *LiveClient) BuildRepoAndDigestURL(repo, digest string) string {
func (c *LiveClient) buildRepoAndDigestURL(repo, digest string) string {
repo = strings.Trim(repo, "/")
return fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest)
}
// GetByRepoAndDigest fetches the attestation by repo and digest
func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) {
url := c.BuildRepoAndDigestURL(repo, digest)
url := c.buildRepoAndDigestURL(repo, digest)
attestations, err := c.getAttestations(url, repo, digest, limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestation by repo and digest: %w", err)
@ -121,7 +127,7 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
err := backoff.Retry(func() error {
newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
@ -170,14 +176,12 @@ func (c *LiveClient) fetchBundlesWithSASURL(attestations []*Attestation) ([]*Att
func (c *LiveClient) fetchBundleWithSASURL(a *Attestation) (*bundle.Bundle, error) {
if a.BundleURL == "" {
//return a.Bundle, nil
return nil, fmt.Errorf("SAS URL is empty")
return nil, fmt.Errorf("bundle URL is empty")
}
c.logger.VerbosePrintf("Fetching attestation bundle\n\n")
httpClient := http.DefaultClient
r, err := httpClient.Get(a.BundleURL)
r, err := c.httpClient.Get(a.BundleURL)
if err != nil {
return nil, err
}
@ -225,7 +229,7 @@ func (c *LiveClient) getTrustDomain(url string) (string, error) {
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
err := backoff.Retry(func() error {
restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
restErr := c.githubAPI.REST(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr

View file

@ -22,7 +22,7 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
if hasNextPage {
return &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage,
},
logger: l,
@ -30,7 +30,7 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
}
return &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccess,
},
logger: l,
@ -50,7 +50,7 @@ func TestGetURL(t *testing.T) {
}
for _, data := range testData {
s := c.BuildRepoAndDigestURL(data.repo, data.digest)
s := c.buildRepoAndDigestURL(data.repo, data.digest)
require.Equal(t, data.expected, s)
}
}
@ -135,7 +135,7 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) {
}
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations,
},
logger: io.NewTestHandler(),
@ -158,7 +158,7 @@ func TestGetByDigest_Error(t *testing.T) {
}
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextError,
},
logger: io.NewTestHandler(),
@ -180,7 +180,7 @@ func TestGetTrustDomain(t *testing.T) {
t.Run("with returned trust domain", func(t *testing.T) {
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnREST: fetcher.OnREST,
},
logger: io.NewTestHandler(),
@ -193,7 +193,7 @@ func TestGetTrustDomain(t *testing.T) {
t.Run("with error", func(t *testing.T) {
c := LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnREST: fetcher.OnRESTError,
},
logger: io.NewTestHandler(),
@ -213,7 +213,7 @@ func TestGetAttestationsRetries(t *testing.T) {
}
c := &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
},
logger: io.NewTestHandler(),
@ -252,7 +252,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) {
}
c := &LiveClient{
api: mockAPIClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnREST500ErrorHandler(),
},
logger: io.NewTestHandler(),