setup testing struct for test cases

Signed-off-by: Meredith Lancaster <malancas@github.com>
This commit is contained in:
Meredith Lancaster 2025-01-07 10:24:42 -07:00
parent 69865117ab
commit 9ecd90c26c
2 changed files with 71 additions and 44 deletions

View file

@ -21,15 +21,15 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
}
l := io.NewTestHandler()
mockHTTPClient := &mockHttpClient{}
if hasNextPage {
return &LiveClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage,
},
httpClient: &mockHttpClient{
OnGet: OnGetSuccess,
},
logger: l,
httpClient: mockHTTPClient,
logger: l,
}
}
@ -37,10 +37,8 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccess,
},
httpClient: &mockHttpClient{
OnGet: OnGetSuccess,
},
logger: l,
httpClient: mockHTTPClient,
logger: l,
}
}
@ -145,10 +143,8 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations,
},
httpClient: &mockHttpClient{
OnGet: OnGetSuccess,
},
logger: io.NewTestHandler(),
httpClient: &mockHttpClient{},
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
@ -183,12 +179,46 @@ func TestGetByDigest_Error(t *testing.T) {
require.Nil(t, attestations)
}
func TestFetchBundleByURL(t *testing.T) {
httpClient := mockHttpClient{
OnGet: OnGetSuccess,
func TestFetchBundlesByURL(t *testing.T) {
mockHTTPClient := &mockHttpClient{}
client := LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
att1 := makeTestAttestation()
att2 := makeTestAttestation()
attestations := []*Attestation{&att1, &att2}
fetched, err := client.fetchBundlesByURL(attestations)
require.NoError(t, err)
require.Len(t, fetched, 2)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType())
require.Equal(t, 2, mockHTTPClient.currNumCalls)
}
func TestFetchBundlesByURL_Fail(t *testing.T) {
mockHTTPClient := &mockHttpClient{
failAfterNumCalls: 2,
}
c := &LiveClient{
httpClient: &httpClient,
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
att1 := makeTestAttestation()
att2 := makeTestAttestation()
attestations := []*Attestation{&att1, &att2}
fetched, err := c.fetchBundlesByURL(attestations)
require.Error(t, err)
require.Nil(t, fetched)
}
func TestFetchBundleByURL(t *testing.T) {
mockHTTPClient := &mockHttpClient{}
c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -196,14 +226,16 @@ func TestFetchBundleByURL(t *testing.T) {
bundle, err := c.fetchBundleByURL(&attestation)
require.NoError(t, err)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
require.True(t, httpClient.called)
require.Equal(t, 1, mockHTTPClient.currNumCalls)
}
func TestFetchBundleByURL_FetchByURLFail(t *testing.T) {
httpClient := mockHttpClient{
OnGet: OnGetFail,
mockHTTPClient := &mockHttpClient{
failAfterNumCalls: 1,
}
c := &LiveClient{
httpClient: &httpClient,
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -211,15 +243,14 @@ func TestFetchBundleByURL_FetchByURLFail(t *testing.T) {
bundle, err := c.fetchBundleByURL(&attestation)
require.Error(t, err)
require.Nil(t, bundle)
require.True(t, httpClient.called)
require.Equal(t, 1, mockHTTPClient.currNumCalls)
}
func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
httpClient := mockHttpClient{
OnGet: OnGetSuccess,
}
mockHTTPClient := &mockHttpClient{}
c := &LiveClient{
httpClient: &httpClient,
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}
@ -227,7 +258,7 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
bundle, err := c.fetchBundleByURL(&attestation)
require.NoError(t, err)
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
require.False(t, httpClient.called)
require.Equal(t, 0, mockHTTPClient.currNumCalls)
}
func TestGetTrustDomain(t *testing.T) {
@ -273,10 +304,8 @@ func TestGetAttestationsRetries(t *testing.T) {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
},
httpClient: &mockHttpClient{
OnGet: OnGetSuccess,
},
logger: io.NewTestHandler(),
httpClient: &mockHttpClient{},
logger: io.NewTestHandler(),
}
attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)

View file

@ -12,28 +12,26 @@ import (
)
type mockHttpClient struct {
mutex sync.RWMutex
called bool
OnGet func(url string) (*http.Response, error)
mutex sync.RWMutex
currNumCalls int
failAfterNumCalls int
}
func (m *mockHttpClient) Get(url string) (*http.Response, error) {
m.mutex.Lock()
m.called = true
m.currNumCalls++
m.mutex.Unlock()
return m.OnGet(url)
}
func OnGetSuccess(url string) (*http.Response, error) {
compressed := snappy.Encode(nil, data.SigstoreBundleRaw)
if m.failAfterNumCalls > 0 && m.currNumCalls >= m.failAfterNumCalls {
return &http.Response{
StatusCode: 500,
}, fmt.Errorf("failed to fetch with %s", url)
}
var compressed []byte
compressed = snappy.Encode(compressed, data.SigstoreBundleRaw)
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(compressed)),
}, nil
}
func OnGetFail(url string) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
}, fmt.Errorf("failed to fetch with %s", url)
}