diff --git a/go.mod b/go.mod index cd7a82036..5661ac091 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/distribution/reference v0.5.0 github.com/gabriel-vasile/mimetype v1.4.7 github.com/gdamore/tcell/v2 v2.5.4 + github.com/golang/snappy v0.0.4 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.2 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 diff --git a/go.sum b/go.sum index 0ec9bcd6a..d17d1251d 100644 --- a/go.sum +++ b/go.sum @@ -198,6 +198,8 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/certificate-transparency-go v1.2.1 h1:4iW/NwzqOqYEEoCBEFP+jPbBXbLqMpq3CifMyOnDUME= github.com/google/certificate-transparency-go v1.2.1/go.mod h1:bvn/ytAccv+I6+DGkqpvSsEdiVGramgaSC6RD3tEmeE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= diff --git a/pkg/cmd/attestation/api/attestation.go b/pkg/cmd/attestation/api/attestation.go index ea055b293..16b062a96 100644 --- a/pkg/cmd/attestation/api/attestation.go +++ b/pkg/cmd/attestation/api/attestation.go @@ -25,7 +25,8 @@ func newErrNoAttestations(name, digest string) ErrNoAttestations { } type Attestation struct { - Bundle *bundle.Bundle `json:"bundle"` + Bundle *bundle.Bundle `json:"bundle"` + BundleURL string `json:"bundle_url"` } type AttestationsResponse struct { diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index f47b5f759..37579b7bc 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -11,6 +11,11 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/api" ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io" + "github.com/golang/snappy" + v1 "github.com/sigstore/protobuf-specs/gen/pb-go/bundle/v1" + "github.com/sigstore/sigstore-go/pkg/bundle" + "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/encoding/protojson" ) const ( @@ -19,11 +24,20 @@ const ( maxLimitForFetch = 100 ) -type apiClient interface { +// Allow injecting backoff interval in tests. +var getAttestationRetryInterval = time.Millisecond * 200 + +// githubApiClient makes REST calls to the GitHub API +type githubApiClient interface { REST(hostname, method, p string, body io.Reader, data interface{}) error RESTWithNext(hostname, method, p string, body io.Reader, data interface{}) (string, error) } +// httpClient makes HTTP calls to all non-GitHub API endpoints +type httpClient interface { + Get(url string) (*http.Response, error) +} + type Client interface { GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) @@ -31,16 +45,18 @@ type Client interface { } type LiveClient struct { - api apiClient - host string - logger *ioconfig.Handler + githubAPI githubApiClient + httpClient httpClient + host string + logger *ioconfig.Handler } func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClient { return &LiveClient{ - api: api.NewClientFromHTTP(hc), - host: strings.TrimSuffix(host, "/"), - logger: l, + githubAPI: api.NewClientFromHTTP(hc), + host: strings.TrimSuffix(host, "/"), + httpClient: hc, + logger: l, } } @@ -52,7 +68,17 @@ func (c *LiveClient) BuildRepoAndDigestURL(repo, digest string) string { // GetByRepoAndDigest fetches the attestation by repo and digest func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { url := c.BuildRepoAndDigestURL(repo, digest) - return c.getAttestations(url, repo, digest, limit) + attestations, err := c.getAttestations(url, repo, digest, limit) + if err != nil { + return nil, err + } + + bundles, err := c.fetchBundleFromAttestations(attestations) + if err != nil { + return nil, fmt.Errorf("failed to fetch bundle with URL: %w", err) + } + + return bundles, nil } func (c *LiveClient) BuildOwnerAndDigestURL(owner, digest string) string { @@ -63,7 +89,21 @@ func (c *LiveClient) BuildOwnerAndDigestURL(owner, digest string) string { // GetByOwnerAndDigest fetches attestation by owner and digest func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { url := c.BuildOwnerAndDigestURL(owner, digest) - return c.getAttestations(url, owner, digest, limit) + attestations, err := c.getAttestations(url, owner, digest, limit) + if err != nil { + return nil, err + } + + if len(attestations) == 0 { + return nil, newErrNoAttestations(owner, digest) + } + + bundles, err := c.fetchBundleFromAttestations(attestations) + if err != nil { + return nil, fmt.Errorf("failed to fetch bundle with URL: %w", err) + } + + return bundles, nil } // GetTrustDomain returns the current trust domain. If the default is used @@ -72,9 +112,6 @@ func (c *LiveClient) GetTrustDomain() (string, error) { return c.getTrustDomain(MetaPath) } -// Allow injecting backoff interval in tests. -var getAttestationRetryInterval = time.Millisecond * 200 - func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) { c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) @@ -97,7 +134,7 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At // if no attestation or less than limit, then keep fetching for url != "" && len(attestations) < limit { err := backoff.Retry(func() error { - newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) + newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) if restErr != nil { if shouldRetry(restErr) { @@ -130,6 +167,77 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At return attestations, nil } +func (c *LiveClient) fetchBundleFromAttestations(attestations []*Attestation) ([]*Attestation, error) { + fetched := make([]*Attestation, len(attestations)) + g := errgroup.Group{} + for i, a := range attestations { + g.Go(func() error { + if a.Bundle == nil && a.BundleURL == "" { + return fmt.Errorf("attestation has no bundle or bundle URL") + } + + // for now, we fallback to the bundle field if the bundle URL is empty + if a.BundleURL == "" { + c.logger.VerbosePrintf("Bundle URL is empty. Falling back to bundle field\n\n") + fetched[i] = &Attestation{ + Bundle: a.Bundle, + } + return nil + } + + // otherwise fetch the bundle with the provided URL + b, err := c.GetBundle(a.BundleURL) + if err != nil { + return fmt.Errorf("failed to fetch bundle with URL: %w", err) + } + fetched[i] = &Attestation{ + Bundle: b, + } + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + + return fetched, nil +} + +func (c *LiveClient) GetBundle(url string) (*bundle.Bundle, error) { + c.logger.VerbosePrintf("Fetching attestation bundle with bundle URL\n\n") + + resp, err := c.httpClient.Get(url) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read blob storage response body: %w", err) + } + + var out []byte + decompressed, err := snappy.Decode(out, body) + if err != nil { + return nil, fmt.Errorf("failed to decompress with snappy: %w", err) + } + + var pbBundle v1.Bundle + if err = protojson.Unmarshal(decompressed, &pbBundle); err != nil { + return nil, fmt.Errorf("failed to unmarshal to bundle: %w", err) + } + + c.logger.VerbosePrintf("Successfully fetched bundle\n\n") + + return bundle.NewBundle(&pbBundle) +} + func shouldRetry(err error) bool { var httpError api.HTTPError if errors.As(err, &httpError) { @@ -146,7 +254,7 @@ func (c *LiveClient) getTrustDomain(url string) (string, error) { bo := backoff.NewConstantBackOff(getAttestationRetryInterval) err := backoff.Retry(func() error { - restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp) + restErr := c.githubAPI.REST(c.host, http.MethodGet, url, nil, &resp) if restErr != nil { if shouldRetry(restErr) { return restErr diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index adac00598..65f8d59ca 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/cli/cli/v2/pkg/cmd/attestation/io" + "github.com/cli/cli/v2/pkg/cmd/attestation/test/data" "github.com/stretchr/testify/require" ) @@ -20,20 +21,24 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } l := io.NewTestHandler() + httpClient := &mockHttpClient{} + if hasNextPage { return &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage, }, - logger: l, + httpClient: httpClient, + logger: l, } } return &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTSuccess, }, - logger: l, + httpClient: httpClient, + logger: l, } } @@ -134,11 +139,13 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { NumAttestations: 5, } + httpClient := &mockHttpClient{} c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations, }, - logger: io.NewTestHandler(), + httpClient: httpClient, + logger: io.NewTestHandler(), } attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) @@ -158,7 +165,7 @@ func TestGetByDigest_Error(t *testing.T) { } c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnRESTWithNextError, }, logger: io.NewTestHandler(), @@ -173,6 +180,86 @@ func TestGetByDigest_Error(t *testing.T) { require.Nil(t, attestations) } +func TestFetchBundleFromAttestations(t *testing.T) { + httpClient := &mockHttpClient{} + client := LiveClient{ + httpClient: httpClient, + logger: io.NewTestHandler(), + } + + att1 := makeTestAttestation() + att2 := makeTestAttestation() + attestations := []*Attestation{&att1, &att2} + fetched, err := client.fetchBundleFromAttestations(attestations) + require.NoError(t, err) + require.Len(t, fetched, 2) + require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType()) + httpClient.AssertNumberOfCalls(t, "OnGetSuccess", 2) +} + +func TestFetchBundleFromAttestations_InvalidAttestation(t *testing.T) { + httpClient := &mockHttpClient{} + client := LiveClient{ + httpClient: httpClient, + logger: io.NewTestHandler(), + } + + att1 := Attestation{} + attestations := []*Attestation{&att1} + fetched, err := client.fetchBundleFromAttestations(attestations) + require.Error(t, err) + require.Nil(t, fetched, 2) +} + +func TestFetchBundleFromAttestations_Fail(t *testing.T) { + httpClient := &failAfterOneCallHttpClient{} + + c := &LiveClient{ + httpClient: httpClient, + logger: io.NewTestHandler(), + } + + att1 := makeTestAttestation() + att2 := makeTestAttestation() + attestations := []*Attestation{&att1, &att2} + fetched, err := c.fetchBundleFromAttestations(attestations) + require.Error(t, err) + require.Nil(t, fetched) + httpClient.AssertNumberOfCalls(t, "OnGetFailAfterOneCall", 2) +} + +func TestFetchBundleFromAttestations_FetchByURLFail(t *testing.T) { + mockHTTPClient := &failHttpClient{} + + c := &LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + a := makeTestAttestation() + attestations := []*Attestation{&a} + bundle, err := c.fetchBundleFromAttestations(attestations) + require.Error(t, err) + require.Nil(t, bundle) + mockHTTPClient.AssertNumberOfCalls(t, "OnGetFail", 1) +} + +func TestFetchBundleByURL_FallbackToBundleField(t *testing.T) { + mockHTTPClient := &mockHttpClient{} + + c := &LiveClient{ + httpClient: mockHTTPClient, + logger: io.NewTestHandler(), + } + + a := Attestation{Bundle: data.SigstoreBundle(t)} + attestations := []*Attestation{&a} + fetched, err := c.fetchBundleFromAttestations(attestations) + require.NoError(t, err) + require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", fetched[0].Bundle.GetMediaType()) + mockHTTPClient.AssertNotCalled(t, "OnGetSuccess") +} + func TestGetTrustDomain(t *testing.T) { fetcher := mockMetaGenerator{ TrustDomain: "foo", @@ -180,7 +267,7 @@ func TestGetTrustDomain(t *testing.T) { t.Run("with returned trust domain", func(t *testing.T) { c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnREST: fetcher.OnREST, }, logger: io.NewTestHandler(), @@ -193,7 +280,7 @@ func TestGetTrustDomain(t *testing.T) { t.Run("with error", func(t *testing.T) { c := LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnREST: fetcher.OnRESTError, }, logger: io.NewTestHandler(), @@ -213,10 +300,11 @@ func TestGetAttestationsRetries(t *testing.T) { } c := &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(), }, - logger: io.NewTestHandler(), + httpClient: &mockHttpClient{}, + logger: io.NewTestHandler(), } attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) @@ -252,7 +340,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { } c := &LiveClient{ - api: mockAPIClient{ + githubAPI: mockAPIClient{ OnRESTWithNext: fetcher.OnREST500ErrorHandler(), }, logger: io.NewTestHandler(), diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index bcb51c414..b2fd334c0 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -25,7 +25,7 @@ func (m MockClient) GetTrustDomain() (string, error) { } func makeTestAttestation() Attestation { - return Attestation{Bundle: data.SigstoreBundle(nil)} + return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} } func OnGetByRepoAndDigestSuccess(repo, digest string, limit int) ([]*Attestation, error) { diff --git a/pkg/cmd/attestation/api/mock_apiClient_test.go b/pkg/cmd/attestation/api/mock_githubApiClient_test.go similarity index 100% rename from pkg/cmd/attestation/api/mock_apiClient_test.go rename to pkg/cmd/attestation/api/mock_githubApiClient_test.go diff --git a/pkg/cmd/attestation/api/mock_httpClient_test.go b/pkg/cmd/attestation/api/mock_httpClient_test.go new file mode 100644 index 000000000..4dc34fbf8 --- /dev/null +++ b/pkg/cmd/attestation/api/mock_httpClient_test.go @@ -0,0 +1,64 @@ +package api + +import ( + "bytes" + "fmt" + "io" + "net/http" + + "github.com/cli/cli/v2/pkg/cmd/attestation/test/data" + "github.com/golang/snappy" + "github.com/stretchr/testify/mock" +) + +type mockHttpClient struct { + mock.Mock +} + +func (m *mockHttpClient) Get(url string) (*http.Response, error) { + m.On("OnGetSuccess").Return() + m.MethodCalled("OnGetSuccess") + + var compressed []byte + compressed = snappy.Encode(compressed, data.SigstoreBundleRaw) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(compressed)), + }, nil +} + +type failHttpClient struct { + mock.Mock +} + +func (m *failHttpClient) Get(url string) (*http.Response, error) { + m.On("OnGetFail").Return() + m.MethodCalled("OnGetFail") + + return &http.Response{ + StatusCode: 500, + }, fmt.Errorf("failed to fetch with %s", url) +} + +type failAfterOneCallHttpClient struct { + mock.Mock +} + +func (m *failAfterOneCallHttpClient) Get(url string) (*http.Response, error) { + m.On("OnGetFailAfterOneCall").Return() + + if len(m.Calls) >= 1 { + m.MethodCalled("OnGetFailAfterOneCall") + return &http.Response{ + StatusCode: 500, + }, fmt.Errorf("failed to fetch with %s", url) + } + + m.MethodCalled("OnGetFailAfterOneCall") + var compressed []byte + compressed = snappy.Encode(compressed, data.SigstoreBundleRaw) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(compressed)), + }, nil +}