From ecf55c6c16f50279e98526b2a953630832e77a31 Mon Sep 17 00:00:00 2001 From: Meredith Lancaster Date: Tue, 7 Jan 2025 10:54:17 -0700 Subject: [PATCH] use mock to assert number of http calls Signed-off-by: Meredith Lancaster --- pkg/cmd/attestation/api/client_test.go | 8 ++++---- pkg/cmd/attestation/api/mock_httpClient_test.go | 11 +++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 23dcb77d9..3572689f8 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -194,7 +194,7 @@ func TestFetchBundlesByURL(t *testing.T) { require.NoError(t, err) require.Len(t, fetched, 2) require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType()) - require.Equal(t, 2, mockHTTPClient.TimesCalled()) + mockHTTPClient.AssertNumberOfCalls(t, "OnGet", 2) } func TestFetchBundlesByURL_Fail(t *testing.T) { @@ -225,7 +225,7 @@ func TestFetchBundleByURL(t *testing.T) { bundle, err := c.fetchBundleByURL(&attestation) require.NoError(t, err) require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType()) - require.Equal(t, 1, mockHTTPClient.TimesCalled()) + mockHTTPClient.AssertNumberOfCalls(t, "OnGet", 1) } func TestFetchBundleByURL_FetchByURLFail(t *testing.T) { @@ -240,7 +240,7 @@ func TestFetchBundleByURL_FetchByURLFail(t *testing.T) { bundle, err := c.fetchBundleByURL(&attestation) require.Error(t, err) require.Nil(t, bundle) - require.Equal(t, 1, mockHTTPClient.TimesCalled()) + mockHTTPClient.AssertNumberOfCalls(t, "OnGet", 1) } func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { @@ -255,7 +255,7 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { bundle, err := c.fetchBundleByURL(&attestation) require.NoError(t, err) require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType()) - require.Equal(t, 0, mockHTTPClient.TimesCalled()) + mockHTTPClient.AssertNotCalled(t, "OnGet") } func TestGetTrustDomain(t *testing.T) { diff --git a/pkg/cmd/attestation/api/mock_httpClient_test.go b/pkg/cmd/attestation/api/mock_httpClient_test.go index 345d2c5c2..dd0c139d5 100644 --- a/pkg/cmd/attestation/api/mock_httpClient_test.go +++ b/pkg/cmd/attestation/api/mock_httpClient_test.go @@ -9,26 +9,33 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/test/data" "github.com/golang/snappy" + "github.com/stretchr/testify/mock" ) type mockHttpClient struct { + mock.Mock mutex sync.RWMutex currNumCalls int alwaysFail bool failAfterNumCalls int + OnGet func(url string) (*http.Response, error) } func (m *mockHttpClient) Get(url string) (*http.Response, error) { + m.On("OnGet").Return() m.mutex.Lock() m.currNumCalls++ m.mutex.Unlock() if m.alwaysFail || (m.failAfterNumCalls > 0 && m.currNumCalls > m.failAfterNumCalls) { + m.MethodCalled("OnGet") return &http.Response{ StatusCode: 500, }, fmt.Errorf("failed to fetch with %s", url) } + m.MethodCalled("OnGet") + var compressed []byte compressed = snappy.Encode(compressed, data.SigstoreBundleRaw) return &http.Response{ @@ -37,10 +44,6 @@ func (m *mockHttpClient) Get(url string) (*http.Response, error) { }, nil } -func (m *mockHttpClient) TimesCalled() int { - return m.currNumCalls -} - func FailHTTPClient() mockHttpClient { return mockHttpClient{ alwaysFail: true,