From 9ecd90c26ceb564e0dea8f492f9f42da952ab6c8 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 7 Jan 2025 10:24:42 -0700 Subject: [PATCH] setup testing struct for test cases Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client_test.go | 89 ++++++++++++------- .../attestation/api/mock_httpClient_test.go | 26 +++--- 2 files changed, 71 insertions(+), 44 deletions(-) diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 4fb0f539a..0f91aff7e 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -21,15 +21,15 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } l := io.NewTestHandler() + mockHTTPClient := &mockHttpClient{} + if hasNextPage { return &LiveClient{ githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage, }, - httpClient: &mockHttpClient{ - OnGet: OnGetSuccess, - }, - logger: l, + httpClient: mockHTTPClient, + logger: l, } } @@ -37,10 +37,8 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTSuccess, }, - httpClient: &mockHttpClient{ - OnGet: OnGetSuccess, - }, - logger: l, + httpClient: mockHTTPClient, + logger: l, } } @@ -145,10 +143,8 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations, }, - httpClient: &mockHttpClient{ - OnGet: OnGetSuccess, - }, - logger: io.NewTestHandler(), + httpClient: &mockHttpClient{}, + logger: io.NewTestHandler(), } attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) @@ -183,12 +179,46 @@ func TestGetByDigest_Error(t *testing.T) { require.Nil(t, attestations) } -func TestFetchBundleByURL(t *testing.T) { - httpClient := mockHttpClient{ - OnGet: OnGetSuccess, +func TestFetchBundlesByURL(t *testing.T) { + mockHTTPClient := &mockHttpClient{} + client := LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), } + + att1 := makeTestAttestation() + att2 := makeTestAttestation() + attestations := []*Attestation{&att1, &att2} + fetched, err := client.fetchBundlesByURL(attestations) + require.NoError(t, err) + require.Len(t, fetched, 2) + require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType()) + require.Equal(t, 2, mockHTTPClient.currNumCalls) +} + +func TestFetchBundlesByURL_Fail(t *testing.T) { + mockHTTPClient := &mockHttpClient{ + failAfterNumCalls: 2, + } + c := &LiveClient{ - httpClient: &httpClient, + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + att1 := makeTestAttestation() + att2 := makeTestAttestation() + attestations := []*Attestation{&att1, &att2} + fetched, err := c.fetchBundlesByURL(attestations) + require.Error(t, err) + require.Nil(t, fetched) +} + +func TestFetchBundleByURL(t *testing.T) { + mockHTTPClient := &mockHttpClient{} + + c := &LiveClient{ + httpClient: mockHTTPClient, logger: io.NewTestHandler(), } @@ -196,14 +226,16 @@ func TestFetchBundleByURL(t *testing.T) { bundle, err := c.fetchBundleByURL(&attestation) require.NoError(t, err) require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType()) - require.True(t, httpClient.called) + require.Equal(t, 1, mockHTTPClient.currNumCalls) } + func TestFetchBundleByURL_FetchByURLFail(t *testing.T) { - httpClient := mockHttpClient{ - OnGet: OnGetFail, + mockHTTPClient := &mockHttpClient{ + failAfterNumCalls: 1, } + c := &LiveClient{ - httpClient: &httpClient, + httpClient: mockHTTPClient, logger: io.NewTestHandler(), } @@ -211,15 +243,14 @@ func TestFetchBundleByURL_FetchByURLFail(t *testing.T) { bundle, err := c.fetchBundleByURL(&attestation) require.Error(t, err) require.Nil(t, bundle) - require.True(t, httpClient.called) + require.Equal(t, 1, mockHTTPClient.currNumCalls) } func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { - httpClient := mockHttpClient{ - OnGet: OnGetSuccess, - } + mockHTTPClient := &mockHttpClient{} + c := &LiveClient{ - httpClient: &httpClient, + httpClient: mockHTTPClient, logger: io.NewTestHandler(), } @@ -227,7 +258,7 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { bundle, err := c.fetchBundleByURL(&attestation) require.NoError(t, err) require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType()) - require.False(t, httpClient.called) + require.Equal(t, 0, mockHTTPClient.currNumCalls) } func TestGetTrustDomain(t *testing.T) { @@ -273,10 +304,8 @@ func TestGetAttestationsRetries(t *testing.T) { githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(), }, - httpClient: &mockHttpClient{ - OnGet: OnGetSuccess, - }, - logger: io.NewTestHandler(), + httpClient: &mockHttpClient{}, + logger: io.NewTestHandler(), } attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) diff --git a/pkg/cmd/attestation/api/mock_httpClient_test.go b/pkg/cmd/attestation/api/mock_httpClient_test.go index 2914d7887..3218ed508 100644 --- a/pkg/cmd/attestation/api/mock_httpClient_test.go +++ b/pkg/cmd/attestation/api/mock_httpClient_test.go @@ -12,28 +12,26 @@ import ( ) type mockHttpClient struct { - mutex sync.RWMutex - called bool - OnGet func(url string) (*http.Response, error) + mutex sync.RWMutex + currNumCalls int + failAfterNumCalls int } func (m *mockHttpClient) Get(url string) (*http.Response, error) { m.mutex.Lock() - m.called = true + m.currNumCalls++ m.mutex.Unlock() - return m.OnGet(url) -} -func OnGetSuccess(url string) (*http.Response, error) { - compressed := snappy.Encode(nil, data.SigstoreBundleRaw) + if m.failAfterNumCalls > 0 && m.currNumCalls >= m.failAfterNumCalls { + return &http.Response{ + StatusCode: 500, + }, fmt.Errorf("failed to fetch with %s", url) + } + + var compressed []byte + compressed = snappy.Encode(compressed, data.SigstoreBundleRaw) return &http.Response{ StatusCode: 200, Body: io.NopCloser(bytes.NewReader(compressed)), }, nil } - -func OnGetFail(url string) (*http.Response, error) { - return &http.Response{ - StatusCode: 500, - }, fmt.Errorf("failed to fetch with %s", url) -}