diff --git a/api/client_test.go b/api/client_test.go index 492609e5c..4c81bf315 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -5,6 +5,8 @@ import ( "io/ioutil" "reflect" "testing" + + "github.com/cli/cli/pkg/httpmock" ) func eq(t *testing.T, got interface{}, expected interface{}) { @@ -15,7 +17,7 @@ func eq(t *testing.T, got interface{}, expected interface{}) { } func TestGraphQL(t *testing.T) { - http := &FakeHTTP{} + http := &httpmock.Registry{} client := NewClient( ReplaceTripper(http), AddHeader("Authorization", "token OTOKEN"), @@ -40,7 +42,7 @@ func TestGraphQL(t *testing.T) { } func TestGraphQLError(t *testing.T) { - http := &FakeHTTP{} + http := &httpmock.Registry{} client := NewClient(ReplaceTripper(http)) response := struct{}{} @@ -52,7 +54,7 @@ func TestGraphQLError(t *testing.T) { } func TestRESTGetDelete(t *testing.T) { - http := &FakeHTTP{} + http := &httpmock.Registry{} client := NewClient( ReplaceTripper(http), diff --git a/api/fake_http.go b/api/fake_http.go deleted file mode 100644 index d8b7506a6..000000000 --- a/api/fake_http.go +++ /dev/null @@ -1,120 +0,0 @@ -package api - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "path" - "strings" -) - -// FakeHTTP provides a mechanism by which to stub HTTP responses through -type FakeHTTP struct { - // Requests stores references to sequential requests that RoundTrip has received - Requests []*http.Request - count int - responseStubs []*http.Response -} - -// StubResponse pre-records an HTTP response -func (f *FakeHTTP) StubResponse(status int, body io.Reader) { - resp := &http.Response{ - StatusCode: status, - Body: ioutil.NopCloser(body), - } - f.responseStubs = append(f.responseStubs, resp) -} - -// RoundTrip satisfies http.RoundTripper -func (f *FakeHTTP) RoundTrip(req *http.Request) (*http.Response, error) { - if len(f.responseStubs) <= f.count { - return nil, fmt.Errorf("FakeHTTP: missing response stub for request %d", f.count) - } - resp := f.responseStubs[f.count] - f.count++ - resp.Request = req - f.Requests = append(f.Requests, req) - return resp, nil -} - -func (f *FakeHTTP) StubWithFixture(status int, fixtureFileName string) func() { - fixturePath := path.Join("../test/fixtures/", fixtureFileName) - fixtureFile, _ := os.Open(fixturePath) - f.StubResponse(status, fixtureFile) - return func() { fixtureFile.Close() } -} - -func (f *FakeHTTP) StubRepoResponse(owner, repo string) { - f.StubRepoResponseWithPermission(owner, repo, "WRITE") -} - -func (f *FakeHTTP) StubRepoResponseWithPermission(owner, repo, permission string) { - body := bytes.NewBufferString(fmt.Sprintf(` - { "data": { "repo_000": { - "id": "REPOID", - "name": "%s", - "owner": {"login": "%s"}, - "defaultBranchRef": { - "name": "master" - }, - "viewerPermission": "%s" - } } } - `, repo, owner, permission)) - resp := &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(body), - } - f.responseStubs = append(f.responseStubs, resp) -} - -func (f *FakeHTTP) StubRepoResponseWithDefaultBranch(owner, repo, defaultBranch string) { - body := bytes.NewBufferString(fmt.Sprintf(` - { "data": { "repo_000": { - "id": "REPOID", - "name": "%s", - "owner": {"login": "%s"}, - "defaultBranchRef": { - "name": "%s" - }, - "viewerPermission": "READ" - } } } - `, repo, owner, defaultBranch)) - resp := &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(body), - } - f.responseStubs = append(f.responseStubs, resp) -} - -func (f *FakeHTTP) StubForkedRepoResponse(forkFullName, parentFullName string) { - forkRepo := strings.SplitN(forkFullName, "/", 2) - parentRepo := strings.SplitN(parentFullName, "/", 2) - body := bytes.NewBufferString(fmt.Sprintf(` - { "data": { "repo_000": { - "id": "REPOID2", - "name": "%s", - "owner": {"login": "%s"}, - "defaultBranchRef": { - "name": "master" - }, - "viewerPermission": "ADMIN", - "parent": { - "id": "REPOID1", - "name": "%s", - "owner": {"login": "%s"}, - "defaultBranchRef": { - "name": "master" - }, - "viewerPermission": "READ" - } - } } } - `, forkRepo[1], forkRepo[0], parentRepo[1], parentRepo[0])) - resp := &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(body), - } - f.responseStubs = append(f.responseStubs, resp) -} diff --git a/api/queries_issue_test.go b/api/queries_issue_test.go index b2dc7c819..c8ee505cf 100644 --- a/api/queries_issue_test.go +++ b/api/queries_issue_test.go @@ -7,10 +7,11 @@ import ( "testing" "github.com/cli/cli/internal/ghrepo" + "github.com/cli/cli/pkg/httpmock" ) func TestIssueList(t *testing.T) { - http := &FakeHTTP{} + http := &httpmock.Registry{} client := NewClient(ReplaceTripper(http)) http.StubResponse(200, bytes.NewBufferString(` diff --git a/api/queries_repo_test.go b/api/queries_repo_test.go index 49202f628..d8a75f4ff 100644 --- a/api/queries_repo_test.go +++ b/api/queries_repo_test.go @@ -5,10 +5,12 @@ import ( "encoding/json" "io/ioutil" "testing" + + "github.com/cli/cli/pkg/httpmock" ) func Test_RepoCreate(t *testing.T) { - http := &FakeHTTP{} + http := &httpmock.Registry{} client := NewClient(ReplaceTripper(http)) http.StubResponse(200, bytes.NewBufferString(`{}`)) diff --git a/command/pr_create_test.go b/command/pr_create_test.go index c84a8ab0c..f69c5c0c6 100644 --- a/command/pr_create_test.go +++ b/command/pr_create_test.go @@ -9,6 +9,7 @@ import ( "github.com/cli/cli/context" "github.com/cli/cli/git" + "github.com/cli/cli/pkg/httpmock" "github.com/cli/cli/test" ) @@ -408,20 +409,27 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) { func TestPRCreate_survey_defaults_monocommit(t *testing.T) { initBlankContext("", "OWNER/REPO", "feature") http := initFakeHTTP() - http.StubRepoResponse("OWNER", "REPO") - http.StubResponse(200, bytes.NewBufferString(` - { "data": { "repository": { "forks": { "nodes": [ - ] } } } } + defer http.Verify(t) + http.Register(httpmock.GraphQL(`\bviewerPermission\b`), httpmock.StringResponse(httpmock.RepoNetworkStubResponse("OWNER", "REPO", "master", "WRITE"))) + http.Register(httpmock.GraphQL(`\bforks\(`), httpmock.StringResponse(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } `)) - http.StubResponse(200, bytes.NewBufferString(` + http.Register(httpmock.GraphQL(`\bpullRequests\(`), httpmock.StringResponse(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } `)) - http.StubResponse(200, bytes.NewBufferString(` + http.Register(httpmock.GraphQL(`\bcreatePullRequest\(`), httpmock.GraphQLMutation(` { "data": { "createPullRequest": { "pullRequest": { "URL": "https://github.com/OWNER/REPO/pull/12" } } } } - `)) + `, func(inputs map[string]interface{}) { + eq(t, inputs["repositoryId"], "REPOID") + eq(t, inputs["title"], "the sky above the port") + eq(t, inputs["body"], "was the color of a television, turned to a dead channel") + eq(t, inputs["baseRefName"], "master") + eq(t, inputs["headRefName"], "feature") + })) cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() @@ -456,29 +464,6 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) { output, err := RunCommand(`pr create`) eq(t, err, nil) - - bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) - reqBody := struct { - Variables struct { - Input struct { - RepositoryID string - Title string - Body string - BaseRefName string - HeadRefName string - } - } - }{} - _ = json.Unmarshal(bodyBytes, &reqBody) - - expectedBody := "was the color of a television, turned to a dead channel" - - eq(t, reqBody.Variables.Input.RepositoryID, "REPOID") - eq(t, reqBody.Variables.Input.Title, "the sky above the port") - eq(t, reqBody.Variables.Input.Body, expectedBody) - eq(t, reqBody.Variables.Input.BaseRefName, "master") - eq(t, reqBody.Variables.Input.HeadRefName, "feature") - eq(t, output.String(), "https://github.com/OWNER/REPO/pull/12\n") } diff --git a/command/testing.go b/command/testing.go index 5a7f68e84..784488fd8 100644 --- a/command/testing.go +++ b/command/testing.go @@ -11,6 +11,7 @@ import ( "github.com/cli/cli/api" "github.com/cli/cli/context" "github.com/cli/cli/internal/config" + "github.com/cli/cli/pkg/httpmock" "github.com/google/shlex" "github.com/spf13/pflag" ) @@ -93,8 +94,8 @@ func initBlankContext(cfg, repo, branch string) { } } -func initFakeHTTP() *api.FakeHTTP { - http := &api.FakeHTTP{} +func initFakeHTTP() *httpmock.Registry { + http := &httpmock.Registry{} apiClientForContext = func(context.Context) (*api.Client, error) { return api.NewClient(api.ReplaceTripper(http)), nil } diff --git a/context/remote_test.go b/context/remote_test.go index 4b6838e23..9cf7a0adf 100644 --- a/context/remote_test.go +++ b/context/remote_test.go @@ -10,6 +10,7 @@ import ( "github.com/cli/cli/api" "github.com/cli/cli/git" "github.com/cli/cli/internal/ghrepo" + "github.com/cli/cli/pkg/httpmock" ) func eq(t *testing.T, got interface{}, expected interface{}) { @@ -70,7 +71,7 @@ func Test_translateRemotes(t *testing.T) { } func Test_resolvedRemotes_triangularSetup(t *testing.T) { - http := &api.FakeHTTP{} + http := &httpmock.Registry{} apiClient := api.NewClient(api.ReplaceTripper(http)) http.StubResponse(200, bytes.NewBufferString(` @@ -137,7 +138,7 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) { } func Test_resolvedRemotes_forkLookup(t *testing.T) { - http := &api.FakeHTTP{} + http := &httpmock.Registry{} apiClient := api.NewClient(api.ReplaceTripper(http)) http.StubResponse(200, bytes.NewBufferString(` diff --git a/pkg/httpmock/legacy.go b/pkg/httpmock/legacy.go new file mode 100644 index 000000000..9474c3dd8 --- /dev/null +++ b/pkg/httpmock/legacy.go @@ -0,0 +1,89 @@ +package httpmock + +import ( + "fmt" + "io" + "net/http" + "os" + "path" + "strings" +) + +// TODO: clean up methods in this file when there are no more callers + +func (r *Registry) StubResponse(status int, body io.Reader) { + r.Register(MatchAny, func(*http.Request) (*http.Response, error) { + return httpResponse(status, body), nil + }) +} + +func (r *Registry) StubWithFixture(status int, fixtureFileName string) func() { + fixturePath := path.Join("../test/fixtures/", fixtureFileName) + fixtureFile, err := os.Open(fixturePath) + r.Register(MatchAny, func(*http.Request) (*http.Response, error) { + if err != nil { + return nil, err + } + return httpResponse(200, fixtureFile), nil + }) + return func() { + if err == nil { + fixtureFile.Close() + } + } +} + +func (r *Registry) StubRepoResponse(owner, repo string) { + r.StubRepoResponseWithPermission(owner, repo, "WRITE") +} + +func (r *Registry) StubRepoResponseWithPermission(owner, repo, permission string) { + r.Register(MatchAny, StringResponse(RepoNetworkStubResponse(owner, repo, "master", permission))) +} + +func (r *Registry) StubRepoResponseWithDefaultBranch(owner, repo, defaultBranch string) { + r.Register(MatchAny, StringResponse(RepoNetworkStubResponse(owner, repo, defaultBranch, "WRITE"))) +} + +func (r *Registry) StubForkedRepoResponse(ownRepo, parentRepo string) { + r.Register(MatchAny, StringResponse(RepoNetworkStubForkResponse(ownRepo, parentRepo))) +} + +func RepoNetworkStubResponse(owner, repo, defaultBranch, permission string) string { + return fmt.Sprintf(` + { "data": { "repo_000": { + "id": "REPOID", + "name": "%s", + "owner": {"login": "%s"}, + "defaultBranchRef": { + "name": "%s" + }, + "viewerPermission": "%s" + } } } + `, repo, owner, defaultBranch, permission) +} + +func RepoNetworkStubForkResponse(forkFullName, parentFullName string) string { + forkRepo := strings.SplitN(forkFullName, "/", 2) + parentRepo := strings.SplitN(parentFullName, "/", 2) + return fmt.Sprintf(` + { "data": { "repo_000": { + "id": "REPOID2", + "name": "%s", + "owner": {"login": "%s"}, + "defaultBranchRef": { + "name": "master" + }, + "viewerPermission": "ADMIN", + "parent": { + "id": "REPOID1", + "name": "%s", + "owner": {"login": "%s"}, + "defaultBranchRef": { + "name": "master" + }, + "viewerPermission": "READ" + } + } } } + `, forkRepo[1], forkRepo[0], parentRepo[1], parentRepo[0]) +} diff --git a/pkg/httpmock/registry.go b/pkg/httpmock/registry.go new file mode 100644 index 000000000..486d79a06 --- /dev/null +++ b/pkg/httpmock/registry.go @@ -0,0 +1,70 @@ +package httpmock + +import ( + "fmt" + "net/http" + "sync" +) + +type Registry struct { + mu sync.Mutex + stubs []*Stub + Requests []*http.Request +} + +func (r *Registry) Register(m Matcher, resp Responder) { + r.stubs = append(r.stubs, &Stub{ + Matcher: m, + Responder: resp, + }) +} + +type Testing interface { + Errorf(string, ...interface{}) +} + +func (r *Registry) Verify(t Testing) { + n := 0 + for _, s := range r.stubs { + if !s.matched { + n++ + } + } + if n > 0 { + // NOTE: stubs offer no useful reflection, so we can't print details + // about dead stubs and what they were trying to match + t.Errorf("%d unmatched HTTP stubs", n) + } +} + +// RoundTrip satisfies http.RoundTripper +func (r *Registry) RoundTrip(req *http.Request) (*http.Response, error) { + var stub *Stub + + r.mu.Lock() + for _, s := range r.stubs { + if s.matched || !s.Matcher(req) { + continue + } + // TODO: reinstate this check once the legacy layer has been cleaned up + // if stub != nil { + // r.mu.Unlock() + // return nil, fmt.Errorf("more than 1 stub matched %v", req) + // } + stub = s + break // TODO: remove + } + if stub != nil { + stub.matched = true + } + + if stub == nil { + r.mu.Unlock() + return nil, fmt.Errorf("no registered stubs matched %v", req) + } + + r.Requests = append(r.Requests, req) + r.mu.Unlock() + + return stub.Responder(req) +} diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go new file mode 100644 index 000000000..b27ceea6c --- /dev/null +++ b/pkg/httpmock/stub.go @@ -0,0 +1,96 @@ +package httpmock + +import ( + "bytes" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "regexp" + "strings" +) + +type Matcher func(req *http.Request) bool +type Responder func(req *http.Request) (*http.Response, error) + +type Stub struct { + matched bool + Matcher Matcher + Responder Responder +} + +func MatchAny(*http.Request) bool { + return true +} + +func GraphQL(q string) Matcher { + re := regexp.MustCompile(q) + + return func(req *http.Request) bool { + if !strings.EqualFold(req.Method, "POST") { + return false + } + if req.URL.Path != "/graphql" { + return false + } + + var bodyData struct { + Query string + } + _ = decodeJSONBody(req, &bodyData) + + return re.MatchString(bodyData.Query) + } +} + +func readBody(req *http.Request) ([]byte, error) { + bodyCopy := &bytes.Buffer{} + r := io.TeeReader(req.Body, bodyCopy) + req.Body = ioutil.NopCloser(bodyCopy) + return ioutil.ReadAll(r) +} + +func decodeJSONBody(req *http.Request, dest interface{}) error { + b, err := readBody(req) + if err != nil { + return err + } + return json.Unmarshal(b, dest) +} + +func StringResponse(body string) Responder { + return func(*http.Request) (*http.Response, error) { + return httpResponse(200, bytes.NewBufferString(body)), nil + } +} + +func JSONResponse(body interface{}) Responder { + return func(*http.Request) (*http.Response, error) { + b, _ := json.Marshal(body) + return httpResponse(200, bytes.NewBuffer(b)), nil + } +} + +func GraphQLMutation(body string, cb func(map[string]interface{})) Responder { + return func(req *http.Request) (*http.Response, error) { + var bodyData struct { + Variables struct { + Input map[string]interface{} + } + } + err := decodeJSONBody(req, &bodyData) + if err != nil { + return nil, err + } + cb(bodyData.Variables.Input) + + return httpResponse(200, bytes.NewBufferString(body)), nil + } +} + +func httpResponse(status int, body io.Reader) *http.Response { + return &http.Response{ + StatusCode: status, + Body: ioutil.NopCloser(body), + } +} diff --git a/update/update_test.go b/update/update_test.go index bd919f530..2fcb2d6ab 100644 --- a/update/update_test.go +++ b/update/update_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/cli/cli/api" + "github.com/cli/cli/pkg/httpmock" ) func TestCheckForUpdate(t *testing.T) { @@ -51,7 +52,7 @@ func TestCheckForUpdate(t *testing.T) { for _, s := range scenarios { t.Run(s.Name, func(t *testing.T) { - http := &api.FakeHTTP{} + http := &httpmock.Registry{} client := api.NewClient(api.ReplaceTripper(http)) http.StubResponse(200, bytes.NewBufferString(fmt.Sprintf(`{ "tag_name": "%s",