use mock to assert number of http calls
Signed-off-by: Meredith Lancaster <malancas@github.com>
This commit is contained in:
parent
e34e188ee2
commit
ecf55c6c16
2 changed files with 11 additions and 8 deletions
|
|
@ -194,7 +194,7 @@ func TestFetchBundlesByURL(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Len(t, fetched, 2)
|
||||
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType())
|
||||
require.Equal(t, 2, mockHTTPClient.TimesCalled())
|
||||
mockHTTPClient.AssertNumberOfCalls(t, "OnGet", 2)
|
||||
}
|
||||
|
||||
func TestFetchBundlesByURL_Fail(t *testing.T) {
|
||||
|
|
@ -225,7 +225,7 @@ func TestFetchBundleByURL(t *testing.T) {
|
|||
bundle, err := c.fetchBundleByURL(&attestation)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
|
||||
require.Equal(t, 1, mockHTTPClient.TimesCalled())
|
||||
mockHTTPClient.AssertNumberOfCalls(t, "OnGet", 1)
|
||||
}
|
||||
|
||||
func TestFetchBundleByURL_FetchByURLFail(t *testing.T) {
|
||||
|
|
@ -240,7 +240,7 @@ func TestFetchBundleByURL_FetchByURLFail(t *testing.T) {
|
|||
bundle, err := c.fetchBundleByURL(&attestation)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, bundle)
|
||||
require.Equal(t, 1, mockHTTPClient.TimesCalled())
|
||||
mockHTTPClient.AssertNumberOfCalls(t, "OnGet", 1)
|
||||
}
|
||||
|
||||
func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
|
||||
|
|
@ -255,7 +255,7 @@ func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) {
|
|||
bundle, err := c.fetchBundleByURL(&attestation)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
|
||||
require.Equal(t, 0, mockHTTPClient.TimesCalled())
|
||||
mockHTTPClient.AssertNotCalled(t, "OnGet")
|
||||
}
|
||||
|
||||
func TestGetTrustDomain(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -9,26 +9,33 @@ import (
|
|||
|
||||
"github.com/cli/cli/v2/pkg/cmd/attestation/test/data"
|
||||
"github.com/golang/snappy"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockHttpClient struct {
|
||||
mock.Mock
|
||||
mutex sync.RWMutex
|
||||
currNumCalls int
|
||||
alwaysFail bool
|
||||
failAfterNumCalls int
|
||||
OnGet func(url string) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (m *mockHttpClient) Get(url string) (*http.Response, error) {
|
||||
m.On("OnGet").Return()
|
||||
m.mutex.Lock()
|
||||
m.currNumCalls++
|
||||
m.mutex.Unlock()
|
||||
|
||||
if m.alwaysFail || (m.failAfterNumCalls > 0 && m.currNumCalls > m.failAfterNumCalls) {
|
||||
m.MethodCalled("OnGet")
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
}, fmt.Errorf("failed to fetch with %s", url)
|
||||
}
|
||||
|
||||
m.MethodCalled("OnGet")
|
||||
|
||||
var compressed []byte
|
||||
compressed = snappy.Encode(compressed, data.SigstoreBundleRaw)
|
||||
return &http.Response{
|
||||
|
|
@ -37,10 +44,6 @@ func (m *mockHttpClient) Get(url string) (*http.Response, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockHttpClient) TimesCalled() int {
|
||||
return m.currNumCalls
|
||||
}
|
||||
|
||||
func FailHTTPClient() mockHttpClient {
|
||||
return mockHttpClient{
|
||||
alwaysFail: true,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue