diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 47cff27ed..34be1a6bb 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -26,11 +26,15 @@ const ( // Allow injecting backoff interval in tests. var getAttestationRetryInterval = time.Millisecond * 200 -type apiClient interface { +type githubApiClient interface { REST(hostname, method, p string, body io.Reader, data interface{}) error RESTWithNext(hostname, method, p string, body io.Reader, data interface{}) (string, error) } +type httpClient interface { + Get(url string) (*http.Response, error) +} + type Client interface { GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) @@ -38,27 +42,29 @@ type Client interface { } type LiveClient struct { - api apiClient - host string - logger *ioconfig.Handler + githubAPI githubApiClient + httpClient httpClient + host string + logger *ioconfig.Handler } func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClient { return &LiveClient{ - api: api.NewClientFromHTTP(hc), - host: strings.TrimSuffix(host, "/"), - logger: l, + githubAPI: api.NewClientFromHTTP(hc), + host: strings.TrimSuffix(host, "/"), + httpClient: hc, + logger: l, } } -func (c *LiveClient) BuildRepoAndDigestURL(repo, digest string) string { +func (c *LiveClient) buildRepoAndDigestURL(repo, digest string) string { repo = strings.Trim(repo, "/") return fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) } // GetByRepoAndDigest fetches the attestation by repo and digest func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - url := c.BuildRepoAndDigestURL(repo, digest) + url := c.buildRepoAndDigestURL(repo, digest) attestations, err := c.getAttestations(url, repo, digest, limit) if err != nil { return nil, fmt.Errorf("failed to fetch attestation by repo and digest: %w", err) @@ -121,7 +127,7 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At // if no attestation or less than limit, then keep fetching for url != "" && len(attestations) < limit { err := backoff.Retry(func() error { - newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) + newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) if restErr != nil { if shouldRetry(restErr) { @@ -170,14 +176,12 @@ func (c *LiveClient) fetchBundlesWithSASURL(attestations []*Attestation) ([]*Att func (c *LiveClient) fetchBundleWithSASURL(a *Attestation) (*bundle.Bundle, error) { if a.BundleURL == "" { - //return a.Bundle, nil - return nil, fmt.Errorf("SAS URL is empty") + return nil, fmt.Errorf("bundle URL is empty") } c.logger.VerbosePrintf("Fetching attestation bundle\n\n") - httpClient := http.DefaultClient - r, err := httpClient.Get(a.BundleURL) + r, err := c.httpClient.Get(a.BundleURL) if err != nil { return nil, err } @@ -225,7 +229,7 @@ func (c *LiveClient) getTrustDomain(url string) (string, error) { bo := backoff.NewConstantBackOff(getAttestationRetryInterval) err := backoff.Retry(func() error { - restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp) + restErr := c.githubAPI.REST(c.host, http.MethodGet, url, nil, &resp) if restErr != nil { if shouldRetry(restErr) { return restErr diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index adac00598..fb3b338bf 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -22,7 +22,7 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { if hasNextPage { return &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage, }, logger: l, @@ -30,7 +30,7 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } return &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTSuccess, }, logger: l, @@ -50,7 +50,7 @@ func TestGetURL(t *testing.T) { } for _, data := range testData { - s := c.BuildRepoAndDigestURL(data.repo, data.digest) + s := c.buildRepoAndDigestURL(data.repo, data.digest) require.Equal(t, data.expected, s) } } @@ -135,7 +135,7 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { } c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations, }, logger: io.NewTestHandler(), @@ -158,7 +158,7 @@ func TestGetByDigest_Error(t *testing.T) { } c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTWithNextError, }, logger: io.NewTestHandler(), @@ -180,7 +180,7 @@ func TestGetTrustDomain(t *testing.T) { t.Run("with returned trust domain", func(t *testing.T) { c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnREST: fetcher.OnREST, }, logger: io.NewTestHandler(), @@ -193,7 +193,7 @@ func TestGetTrustDomain(t *testing.T) { t.Run("with error", func(t *testing.T) { c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnREST: fetcher.OnRESTError, }, logger: io.NewTestHandler(), @@ -213,7 +213,7 @@ func TestGetAttestationsRetries(t *testing.T) { } c := &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(), }, logger: io.NewTestHandler(), @@ -252,7 +252,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { } c := &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnREST500ErrorHandler(), }, logger: io.NewTestHandler(),