package shared import ( "net/http" "testing" "github.com/cli/cli/api" "github.com/cli/cli/context" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/httpmock" ) func TestPRFromArgs(t *testing.T) { type args struct { baseRepoFn func() (ghrepo.Interface, error) branchFn func() (string, error) remotesFn func() (context.Remotes, error) selector string } tests := []struct { name string args args httpStub func(*httpmock.Registry) wantPR int wantRepo string wantErr bool }{ { name: "number argument", args: args{ selector: "13", baseRepoFn: func() (ghrepo.Interface, error) { return ghrepo.FromFullName("OWNER/REPO") }, }, httpStub: func(r *httpmock.Registry) { r.Register( httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.StringResponse(`{"data":{"repository":{ "pullRequest":{"number":13} }}}`)) }, wantPR: 13, wantRepo: "https://github.com/OWNER/REPO", }, { name: "number with hash argument", args: args{ selector: "#13", baseRepoFn: func() (ghrepo.Interface, error) { return ghrepo.FromFullName("OWNER/REPO") }, }, httpStub: func(r *httpmock.Registry) { r.Register( httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.StringResponse(`{"data":{"repository":{ "pullRequest":{"number":13} }}}`)) }, wantPR: 13, wantRepo: "https://github.com/OWNER/REPO", }, { name: "URL argument", args: args{ selector: "https://example.org/OWNER/REPO/pull/13/files", baseRepoFn: nil, }, httpStub: func(r *httpmock.Registry) { r.Register( httpmock.GraphQL(`query PullRequest_fields\b`), httpmock.StringResponse(`{"data":{}}`)) r.Register( httpmock.GraphQL(`query PullRequest_fields2\b`), httpmock.StringResponse(`{"data":{}}`)) r.Register( httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.StringResponse(`{"data":{"repository":{ "pullRequest":{"number":13} }}}`)) }, wantPR: 13, wantRepo: "https://example.org/OWNER/REPO", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { reg := &httpmock.Registry{} if tt.httpStub != nil { tt.httpStub(reg) } httpClient := &http.Client{Transport: reg} pr, repo, err := PRFromArgs(api.NewClientFromHTTP(httpClient), tt.args.baseRepoFn, tt.args.branchFn, tt.args.remotesFn, tt.args.selector) if (err != nil) != tt.wantErr { t.Errorf("IssueFromArg() error = %v, wantErr %v", err, tt.wantErr) return } if pr.Number != tt.wantPR { t.Errorf("want issue #%d, got #%d", tt.wantPR, pr.Number) } repoURL := ghrepo.GenerateRepoURL(repo, "") if repoURL != tt.wantRepo { t.Errorf("want repo %s, got %s", tt.wantRepo, repoURL) } }) } }