add http client test constructors

Signed-off-by: Meredith Lancaster <malancas@github.com>
This commit is contained in:
Meredith Lancaster 2025-01-07 10:43:24 -07:00
parent 9ecd90c26c
commit e34e188ee2
2 changed files with 39 additions and 21 deletions

View file

@ -21,14 +21,14 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
}
l := io.NewTestHandler()
mockHTTPClient := &mockHttpClient{}
mockHTTPClient := SuccessHTTPClient()
if hasNextPage {
return &LiveClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage,
},
httpClient: mockHTTPClient,
httpClient: &mockHTTPClient,
logger: l,
}
}
@ -37,7 +37,7 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccess,
},
httpClient: mockHTTPClient,
httpClient: &mockHTTPClient,
logger: l,
}
}
@ -139,11 +139,12 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) {
NumAttestations: 5,
}
httpClient := SuccessHTTPClient()
c := LiveClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations,
},
httpClient: &mockHttpClient{},
httpClient: &httpClient,
logger: io.NewTestHandler(),
}
@ -193,16 +194,14 @@ func TestFetchBundlesByURL(t *testing.T) {
require.NoError(t, err)
require.Len(t, fetched, 2)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType())
require.Equal(t, 2, mockHTTPClient.currNumCalls)
require.Equal(t, 2, mockHTTPClient.TimesCalled())
}
func TestFetchBundlesByURL_Fail(t *testing.T) {
mockHTTPClient := &mockHttpClient{
failAfterNumCalls: 2,
}
mockHTTPClient := HTTPClientFailsAfterNumCalls(1)
c := &LiveClient{
httpClient: mockHTTPClient,
httpClient: &mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -215,10 +214,10 @@ func TestFetchBundlesByURL_Fail(t *testing.T) {
}
func TestFetchBundleByURL(t *testing.T) {
mockHTTPClient := &mockHttpClient{}
mockHTTPClient := SuccessHTTPClient()
c := &LiveClient{
httpClient: mockHTTPClient,
httpClient: &mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -226,16 +225,14 @@ func TestFetchBundleByURL(t *testing.T) {
bundle, err := c.fetchBundleByURL(&attestation)
require.NoError(t, err)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
require.Equal(t, 1, mockHTTPClient.currNumCalls)
require.Equal(t, 1, mockHTTPClient.TimesCalled())
}
func TestFetchBundleByURL_FetchByURLFail(t *testing.T) {
mockHTTPClient := &mockHttpClient{
failAfterNumCalls: 1,
}
mockHTTPClient := FailHTTPClient()
c := &LiveClient{
httpClient: mockHTTPClient,
httpClient: &mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -243,14 +240,14 @@ func TestFetchBundleByURL_FetchByURLFail(t *testing.T) {
bundle, err := c.fetchBundleByURL(&attestation)
require.Error(t, err)
require.Nil(t, bundle)
require.Equal(t, 1, mockHTTPClient.currNumCalls)
require.Equal(t, 1, mockHTTPClient.TimesCalled())
}
func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
mockHTTPClient := &mockHttpClient{}
mockHTTPClient := SuccessHTTPClient()
c := &LiveClient{
httpClient: mockHTTPClient,
httpClient: &mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -258,7 +255,7 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
bundle, err := c.fetchBundleByURL(&attestation)
require.NoError(t, err)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
require.Equal(t, 0, mockHTTPClient.currNumCalls)
require.Equal(t, 0, mockHTTPClient.TimesCalled())
}
func TestGetTrustDomain(t *testing.T) {

View file

@ -14,6 +14,7 @@ import (
type mockHttpClient struct {
mutex sync.RWMutex
currNumCalls int
alwaysFail bool
failAfterNumCalls int
}
@ -22,7 +23,7 @@ func (m *mockHttpClient) Get(url string) (*http.Response, error) {
m.currNumCalls++
m.mutex.Unlock()
if m.failAfterNumCalls > 0 && m.currNumCalls >= m.failAfterNumCalls {
if m.alwaysFail || (m.failAfterNumCalls > 0 && m.currNumCalls > m.failAfterNumCalls) {
return &http.Response{
StatusCode: 500,
}, fmt.Errorf("failed to fetch with %s", url)
@ -35,3 +36,23 @@ func (m *mockHttpClient) Get(url string) (*http.Response, error) {
Body: io.NopCloser(bytes.NewReader(compressed)),
}, nil
}
func (m *mockHttpClient) TimesCalled() int {
return m.currNumCalls
}
func FailHTTPClient() mockHttpClient {
return mockHttpClient{
alwaysFail: true,
}
}
func SuccessHTTPClient() mockHttpClient {
return mockHttpClient{}
}
func HTTPClientFailsAfterNumCalls(numCalls int) mockHttpClient {
return mockHttpClient{
failAfterNumCalls: numCalls,
}
}