From 870da79886e418e7c095247653440635edc0fd99 Mon Sep 17 00:00:00 2001 From: William Martin Date: Wed, 18 Dec 2024 13:36:53 +0100 Subject: [PATCH] Use smarter base repo funcs for secret commands --- pkg/cmd/secret/delete/delete.go | 25 ++- pkg/cmd/secret/delete/delete_test.go | 221 +++++++++---------- pkg/cmd/secret/list/list.go | 33 +-- pkg/cmd/secret/list/list_test.go | 271 +++++++++--------------- pkg/cmd/secret/set/set.go | 35 ++- pkg/cmd/secret/set/set_test.go | 240 +++++++++------------ pkg/cmd/secret/shared/base_repo.go | 75 +++++-- pkg/cmd/secret/shared/base_repo_test.go | 239 +++++++++++++++++++++ pkg/cmdutil/flags.go | 8 - 9 files changed, 647 insertions(+), 500 deletions(-) create mode 100644 pkg/cmd/secret/shared/base_repo_test.go diff --git a/pkg/cmd/secret/delete/delete.go b/pkg/cmd/secret/delete/delete.go index 70aab4be5..64b02c9d9 100644 --- a/pkg/cmd/secret/delete/delete.go +++ b/pkg/cmd/secret/delete/delete.go @@ -6,7 +6,6 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" - ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmd/secret/shared" @@ -20,15 +19,12 @@ type DeleteOptions struct { IO *iostreams.IOStreams Config func() (gh.Config, error) BaseRepo func() (ghrepo.Interface, error) - Remotes func() (ghContext.Remotes, error) SecretName string OrgName string EnvName string UserSecrets bool Application string - - HasRepoOverride bool } func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Command { @@ -36,7 +32,6 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co IO: f.IOStreams, Config: f.Config, HttpClient: f.HttpClient, - Remotes: f.Remotes, } cmd := &cobra.Command{ @@ -51,15 +46,24 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co `), Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override + // If the user specified a repo directly, then we're using the OverrideBaseRepoFunc set by EnableRepoOverride + // So there's no reason to use the specialised BaseRepoFunc that requires remote disambiguation. opts.BaseRepo = f.BaseRepo + if !cmd.Flags().Changed("repo") { + // If they haven't specified a repo directly, then we will wrap the BaseRepoFunc in one that error if + // there might be multiple valid remotes. + opts.BaseRepo = shared.RequireNoAmbiguityBaseRepoFunc(opts.BaseRepo, f.Remotes) + // But if we are able to prompt, then we will wrap that up in a BaseRepoFunc that can prompt the user to + // resolve the ambiguity. + if opts.IO.CanPrompt() { + opts.BaseRepo = shared.PromptWhenMultipleRemotesBaseRepoFunc(opts.BaseRepo, f.Prompter) + } + } if err := cmdutil.MutuallyExclusive("specify only one of `--org`, `--env`, or `--user`", opts.OrgName != "", opts.EnvName != "", opts.UserSecrets); err != nil { return err } - opts.HasRepoOverride = cmd.Flags().Changed("repo") - opts.SecretName = args[0] if runF != nil { @@ -110,11 +114,6 @@ func removeRun(opts *DeleteOptions) error { if err != nil { return err } - - err = shared.ValidateHasOnlyOneRemote(opts.HasRepoOverride, opts.Remotes) - if err != nil { - return err - } } cfg, err := opts.Config() diff --git a/pkg/cmd/secret/delete/delete_test.go b/pkg/cmd/secret/delete/delete_test.go index 45091a3b8..91d688a89 100644 --- a/pkg/cmd/secret/delete/delete_test.go +++ b/pkg/cmd/secret/delete/delete_test.go @@ -2,6 +2,7 @@ package delete import ( "bytes" + "io" "net/http" "testing" @@ -10,6 +11,8 @@ import ( "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" + "github.com/cli/cli/v2/pkg/cmd/secret/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" @@ -122,6 +125,108 @@ func TestNewCmdDelete(t *testing.T) { } } +func TestNewCmdDeleteBaseRepoFuncs(t *testing.T) { + remotes := ghContext.Remotes{ + &ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "fork"), + }, + &ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + }, + } + + tests := []struct { + name string + args string + prompterStubs func(*prompter.MockPrompter) + wantRepo ghrepo.Interface + wantErr error + }{ + { + name: "when there is a repo flag provided, the factory base repo func is used", + args: "SECRET_NAME --repo owner/repo", + wantRepo: ghrepo.New("owner", "repo"), + }, + { + name: "when there is no repo flag provided, and no prompting, the base func requiring no ambiguity is used", + args: "SECRET_NAME", + wantErr: shared.MultipleRemotesError{ + Remotes: remotes, + }, + }, + { + name: "when there is no repo flag provided, and can prompt, the base func resolving ambiguity is used", + args: "SECRET_NAME", + prompterStubs: func(pm *prompter.MockPrompter) { + pm.RegisterSelect( + "Select a base repo", + []string{"owner/fork", "owner/repo"}, + func(_, _ string, opts []string) (int, error) { + return prompter.IndexFor(opts, "owner/fork") + }, + ) + }, + wantRepo: ghrepo.New("owner", "fork"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + var pm *prompter.MockPrompter + if tt.prompterStubs != nil { + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + pm = prompter.NewMockPrompter(t) + tt.prompterStubs(pm) + } + + f := &cmdutil.Factory{ + IOStreams: ios, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("owner/repo") + }, + Prompter: pm, + Remotes: func() (ghContext.Remotes, error) { + return remotes, nil + }, + } + + argv, err := shlex.Split(tt.args) + assert.NoError(t, err) + + var gotOpts *DeleteOptions + cmd := NewCmdDelete(f, func(opts *DeleteOptions) error { + gotOpts = opts + return nil + }) + // Require to support --repo flag + cmdutil.EnableRepoOverride(cmd, f) + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + + _, err = cmd.ExecuteC() + require.NoError(t, err) + + baseRepo, err := gotOpts.BaseRepo() + if tt.wantErr != nil { + require.Equal(t, tt.wantErr, err) + return + } + require.True(t, ghrepo.IsSame(tt.wantRepo, baseRepo)) + }) + } +} + func Test_removeRun_repo(t *testing.T) { tests := []struct { name string @@ -353,119 +458,3 @@ func Test_removeRun_user(t *testing.T) { reg.Verify(t) } - -func Test_removeRun_remote_validation(t *testing.T) { - tests := []struct { - name string - opts *DeleteOptions - wantPath string - wantErr bool - errMsg string - }{ - { - name: "single repo detected", - opts: &DeleteOptions{ - Application: "actions", - SecretName: "cool_secret", - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - }, nil - }}, - wantPath: "repos/owner/repo/actions/secrets/cool_secret", - }, - { - name: "multi repo detected", - opts: &DeleteOptions{ - Application: "actions", - SecretName: "cool_secret", - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - remote2 := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "upstream", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - remote2, - }, nil - }}, - wantErr: true, - errMsg: "multiple remotes detected [origin upstream]. please specify which repo to use by providing the -R or --repo argument", - }, - { - name: "multi repo detected - single repo given", - opts: &DeleteOptions{ - Application: "actions", - SecretName: "cool_secret", - HasRepoOverride: true, - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - remote2 := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "upstream", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - remote2, - }, nil - }}, - wantPath: "repos/owner/repo/actions/secrets/cool_secret", - }, - } - - for _, tt := range tests { - reg := &httpmock.Registry{} - - if tt.wantPath != "" { - reg.Register( - httpmock.REST("DELETE", tt.wantPath), - httpmock.StatusStringResponse(204, "No Content")) - } - - ios, _, _, _ := iostreams.Test() - - tt.opts.IO = ios - tt.opts.HttpClient = func() (*http.Client, error) { - return &http.Client{Transport: reg}, nil - } - tt.opts.Config = func() (gh.Config, error) { - return config.NewBlankConfig(), nil - } - tt.opts.BaseRepo = func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("owner/repo") - } - - err := removeRun(tt.opts) - if tt.wantErr { - assert.EqualError(t, err, tt.errMsg) - } else { - assert.NoError(t, err) - } - - reg.Verify(t) - } -} diff --git a/pkg/cmd/secret/list/list.go b/pkg/cmd/secret/list/list.go index 5c13234d0..49ad47aff 100644 --- a/pkg/cmd/secret/list/list.go +++ b/pkg/cmd/secret/list/list.go @@ -9,9 +9,9 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" - ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/internal/tableprinter" "github.com/cli/cli/v2/pkg/cmd/secret/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -24,16 +24,15 @@ type ListOptions struct { IO *iostreams.IOStreams Config func() (gh.Config, error) BaseRepo func() (ghrepo.Interface, error) - Remotes func() (ghContext.Remotes, error) - Now func() time.Time - Exporter cmdutil.Exporter + Prompter prompter.Prompter + + Now func() time.Time + Exporter cmdutil.Exporter OrgName string EnvName string UserSecrets bool Application string - - HasRepoOverride bool } var secretFields = []string{ @@ -51,8 +50,8 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman IO: f.IOStreams, Config: f.Config, HttpClient: f.HttpClient, - Remotes: f.Remotes, Now: time.Now, + Prompter: f.Prompter, } cmd := &cobra.Command{ @@ -68,15 +67,24 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman Aliases: []string{"ls"}, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override + // If the user specified a repo directly, then we're using the OverrideBaseRepoFunc set by EnableRepoOverride + // So there's no reason to use the specialised BaseRepoFunc that requires remote disambiguation. opts.BaseRepo = f.BaseRepo + if !cmd.Flags().Changed("repo") { + // If they haven't specified a repo directly, then we will wrap the BaseRepoFunc in one that error if + // there might be multiple valid remotes. + opts.BaseRepo = shared.RequireNoAmbiguityBaseRepoFunc(opts.BaseRepo, f.Remotes) + // But if we are able to prompt, then we will wrap that up in a BaseRepoFunc that can prompt the user to + // resolve the ambiguity. + if opts.IO.CanPrompt() { + opts.BaseRepo = shared.PromptWhenMultipleRemotesBaseRepoFunc(opts.BaseRepo, f.Prompter) + } + } if err := cmdutil.MutuallyExclusive("specify only one of `--org`, `--env`, or `--user`", opts.OrgName != "", opts.EnvName != "", opts.UserSecrets); err != nil { return err } - opts.HasRepoOverride = cmd.Flags().Changed("repo") - if runF != nil { return runF(opts) } @@ -108,11 +116,6 @@ func listRun(opts *ListOptions) error { if err != nil { return err } - - err = shared.ValidateHasOnlyOneRemote(opts.HasRepoOverride, opts.Remotes) - if err != nil { - return err - } } secretEntity, err := shared.GetSecretEntity(orgName, envName, opts.UserSecrets) diff --git a/pkg/cmd/secret/list/list_test.go b/pkg/cmd/secret/list/list_test.go index a66b913ea..63a899b77 100644 --- a/pkg/cmd/secret/list/list_test.go +++ b/pkg/cmd/secret/list/list_test.go @@ -3,6 +3,7 @@ package list import ( "bytes" "fmt" + "io" "net/http" "net/url" "strings" @@ -14,6 +15,7 @@ import ( "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/cmd/secret/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" @@ -103,6 +105,108 @@ func Test_NewCmdList(t *testing.T) { } } +func TestNewCmdListBaseRepoFuncs(t *testing.T) { + remotes := ghContext.Remotes{ + &ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "fork"), + }, + &ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + }, + } + + tests := []struct { + name string + args string + prompterStubs func(*prompter.MockPrompter) + wantRepo ghrepo.Interface + wantErr error + }{ + { + name: "when there is a repo flag provided, the factory base repo func is used", + args: "--repo owner/repo", + wantRepo: ghrepo.New("owner", "repo"), + }, + { + name: "when there is no repo flag provided, and no prompting, the base func requiring no ambiguity is used", + args: "", + wantErr: shared.MultipleRemotesError{ + Remotes: remotes, + }, + }, + { + name: "when there is no repo flag provided, and can prompt, the base func resolving ambiguity is used", + args: "", + prompterStubs: func(pm *prompter.MockPrompter) { + pm.RegisterSelect( + "Select a base repo", + []string{"owner/fork", "owner/repo"}, + func(_, _ string, opts []string) (int, error) { + return prompter.IndexFor(opts, "owner/fork") + }, + ) + }, + wantRepo: ghrepo.New("owner", "fork"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + var pm *prompter.MockPrompter + if tt.prompterStubs != nil { + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + pm = prompter.NewMockPrompter(t) + tt.prompterStubs(pm) + } + + f := &cmdutil.Factory{ + IOStreams: ios, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("owner/repo") + }, + Prompter: pm, + Remotes: func() (ghContext.Remotes, error) { + return remotes, nil + }, + } + + argv, err := shlex.Split(tt.args) + assert.NoError(t, err) + + var gotOpts *ListOptions + cmd := NewCmdList(f, func(opts *ListOptions) error { + gotOpts = opts + return nil + }) + // Require to support --repo flag + cmdutil.EnableRepoOverride(cmd, f) + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + + _, err = cmd.ExecuteC() + require.NoError(t, err) + + baseRepo, err := gotOpts.BaseRepo() + if tt.wantErr != nil { + require.Equal(t, tt.wantErr, err) + return + } + require.True(t, ghrepo.IsSame(tt.wantRepo, baseRepo)) + }) + } +} + func Test_listRun(t *testing.T) { tests := []struct { name string @@ -444,173 +548,6 @@ func Test_listRun(t *testing.T) { } } -func Test_listRunRemoteValidation(t *testing.T) { - tests := []struct { - name string - tty bool - json bool - opts *ListOptions - wantOut []string - wantErr bool - errMsg string - }{ - { - name: "single repo detected", - tty: false, - opts: &ListOptions{ - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - }, nil - }, - }, - wantOut: []string{ - "SECRET_ONE\t1988-10-11T00:00:00Z", - "SECRET_TWO\t2020-12-04T00:00:00Z", - "SECRET_THREE\t1975-11-30T00:00:00Z", - }, - }, - { - name: "multi repo detected", - tty: false, - opts: &ListOptions{ - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - remote2 := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "upstream", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - remote2, - }, nil - }, - }, - wantOut: []string{}, - wantErr: true, - errMsg: "multiple remotes detected [origin upstream]. please specify which repo to use by providing the -R or --repo argument", - }, - { - name: "multi repo detected - single repo given", - tty: false, - opts: &ListOptions{ - HasRepoOverride: true, - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - remote2 := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "upstream", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - remote2, - }, nil - }, - }, - wantOut: []string{ - "SECRET_ONE\t1988-10-11T00:00:00Z", - "SECRET_TWO\t2020-12-04T00:00:00Z", - "SECRET_THREE\t1975-11-30T00:00:00Z", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reg := &httpmock.Registry{} - reg.Verify(t) - - path := "repos/owner/repo/actions/secrets" - - t0, _ := time.Parse("2006-01-02", "1988-10-11") - t1, _ := time.Parse("2006-01-02", "2020-12-04") - t2, _ := time.Parse("2006-01-02", "1975-11-30") - payload := struct { - Secrets []Secret - }{ - Secrets: []Secret{ - { - Name: "SECRET_ONE", - UpdatedAt: t0, - }, - { - Name: "SECRET_TWO", - UpdatedAt: t1, - }, - { - Name: "SECRET_THREE", - UpdatedAt: t2, - }, - }, - } - - reg.Register(httpmock.REST("GET", path), httpmock.JSONResponse(payload)) - - ios, _, stdout, _ := iostreams.Test() - - ios.SetStdoutTTY(tt.tty) - - tt.opts.IO = ios - tt.opts.BaseRepo = func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("owner/repo") - } - tt.opts.HttpClient = func() (*http.Client, error) { - return &http.Client{Transport: reg}, nil - } - tt.opts.Config = func() (gh.Config, error) { - return config.NewBlankConfig(), nil - } - tt.opts.Now = func() time.Time { - t, _ := time.Parse(time.RFC822, "15 Mar 23 00:00 UTC") - return t - } - - if tt.json { - exporter := cmdutil.NewJSONExporter() - exporter.SetFields(secretFields) - tt.opts.Exporter = exporter - } - - err := listRun(tt.opts) - if tt.wantErr { - assert.EqualError(t, err, tt.errMsg) - - return - } - - assert.NoError(t, err) - - if len(tt.wantOut) > 1 { - expected := fmt.Sprintf("%s\n", strings.Join(tt.wantOut, "\n")) - assert.Equal(t, expected, stdout.String()) - } - }) - } -} - // Test_listRun_populatesNumSelectedReposIfRequired asserts that NumSelectedRepos // field is populated **only** when it's going to be presented in the output. Since // populating this field costs further API requests (one per secret), it's important diff --git a/pkg/cmd/secret/set/set.go b/pkg/cmd/secret/set/set.go index b4630e067..ba9b023a6 100644 --- a/pkg/cmd/secret/set/set.go +++ b/pkg/cmd/secret/set/set.go @@ -13,7 +13,6 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" - ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmd/secret/shared" @@ -30,7 +29,6 @@ type SetOptions struct { IO *iostreams.IOStreams Config func() (gh.Config, error) BaseRepo func() (ghrepo.Interface, error) - Remotes func() (ghContext.Remotes, error) Prompter prompter.Prompter RandomOverride func() io.Reader @@ -45,9 +43,6 @@ type SetOptions struct { RepositoryNames []string EnvFile string Application string - - HasRepoOverride bool - CanPrompt bool } func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command { @@ -55,7 +50,6 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command IO: f.IOStreams, Config: f.Config, HttpClient: f.HttpClient, - Remotes: f.Remotes, Prompter: f.Prompter, } @@ -110,8 +104,19 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override + // If the user specified a repo directly, then we're using the OverrideBaseRepoFunc set by EnableRepoOverride + // So there's no reason to use the specialised BaseRepoFunc that requires remote disambiguation. opts.BaseRepo = f.BaseRepo + if !cmd.Flags().Changed("repo") { + // If they haven't specified a repo directly, then we will wrap the BaseRepoFunc in one that error if + // there might be multiple valid remotes. + opts.BaseRepo = shared.RequireNoAmbiguityBaseRepoFunc(opts.BaseRepo, f.Remotes) + // But if we are able to prompt, then we will wrap that up in a BaseRepoFunc that can prompt the user to + // resolve the ambiguity. + if opts.IO.CanPrompt() { + opts.BaseRepo = shared.PromptWhenMultipleRemotesBaseRepoFunc(opts.BaseRepo, f.Prompter) + } + } if err := cmdutil.MutuallyExclusive("specify only one of `--org`, `--env`, or `--user`", opts.OrgName != "", opts.EnvName != "", opts.UserSecrets); err != nil { return err @@ -151,9 +156,6 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command } } - opts.HasRepoOverride = cmd.Flags().Changed("repo") - opts.CanPrompt = opts.IO.CanPrompt() - if runF != nil { return runF(opts) } @@ -198,19 +200,6 @@ func setRun(opts *SetOptions) error { return err } - if err = shared.ValidateHasOnlyOneRemote(opts.HasRepoOverride, opts.Remotes); err != nil { - if !opts.CanPrompt { - return err - } - - selectedRepo, errSelectingRepo := shared.PromptForRepo(baseRepo, opts.Remotes, opts.Prompter) - if errSelectingRepo != nil { - return errSelectingRepo - } - - baseRepo = selectedRepo - } - host = baseRepo.RepoHost() } else { cfg, err := opts.Config() diff --git a/pkg/cmd/secret/set/set_test.go b/pkg/cmd/secret/set/set_test.go index 68c7624dd..63834521f 100644 --- a/pkg/cmd/secret/set/set_test.go +++ b/pkg/cmd/secret/set/set_test.go @@ -22,6 +22,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCmdSet(t *testing.T) { @@ -223,6 +224,108 @@ func TestNewCmdSet(t *testing.T) { } } +func TestNewCmdSetBaseRepoFuncs(t *testing.T) { + remotes := ghContext.Remotes{ + &ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "fork"), + }, + &ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + }, + } + + tests := []struct { + name string + args string + prompterStubs func(*prompter.MockPrompter) + wantRepo ghrepo.Interface + wantErr error + }{ + { + name: "when there is a repo flag provided, the factory base repo func is used", + args: "SECRET_NAME --repo owner/repo", + wantRepo: ghrepo.New("owner", "repo"), + }, + { + name: "when there is no repo flag provided, and no prompting, the base func requiring no ambiguity is used", + args: "SECRET_NAME", + wantErr: shared.MultipleRemotesError{ + Remotes: remotes, + }, + }, + { + name: "when there is no repo flag provided, and can prompt, the base func resolving ambiguity is used", + args: "SECRET_NAME", + prompterStubs: func(pm *prompter.MockPrompter) { + pm.RegisterSelect( + "Select a base repo", + []string{"owner/fork", "owner/repo"}, + func(_, _ string, opts []string) (int, error) { + return prompter.IndexFor(opts, "owner/fork") + }, + ) + }, + wantRepo: ghrepo.New("owner", "fork"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + var pm *prompter.MockPrompter + if tt.prompterStubs != nil { + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + pm = prompter.NewMockPrompter(t) + tt.prompterStubs(pm) + } + + f := &cmdutil.Factory{ + IOStreams: ios, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.FromFullName("owner/repo") + }, + Prompter: pm, + Remotes: func() (ghContext.Remotes, error) { + return remotes, nil + }, + } + + argv, err := shlex.Split(tt.args) + assert.NoError(t, err) + + var gotOpts *SetOptions + cmd := NewCmdSet(f, func(opts *SetOptions) error { + gotOpts = opts + return nil + }) + // Require to support --repo flag + cmdutil.EnableRepoOverride(cmd, f) + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + + _, err = cmd.ExecuteC() + require.NoError(t, err) + + baseRepo, err := gotOpts.BaseRepo() + if tt.wantErr != nil { + require.Equal(t, tt.wantErr, err) + return + } + require.True(t, ghrepo.IsSame(tt.wantRepo, baseRepo)) + }) + } +} + func Test_setRun_repo(t *testing.T) { tests := []struct { name string @@ -702,143 +805,6 @@ func Test_getSecretsFromOptions(t *testing.T) { } } -func Test_setRun_remote_validation(t *testing.T) { - tests := []struct { - name string - opts *SetOptions - wantApp string - wantErr bool - errMsg string - }{ - { - name: "single repo detected", - opts: &SetOptions{ - Application: "actions", - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - }, nil - }, - }, - wantApp: "actions", - }, - { - name: "multi repo detected", - opts: &SetOptions{ - Application: "actions", - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - remote2 := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "upstream", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - remote2, - }, nil - }, - }, - wantErr: true, - errMsg: "multiple remotes detected [origin upstream]. please specify which repo to use by providing the -R or --repo argument", - }, - { - name: "multi repo detected - single repo given", - opts: &SetOptions{ - Application: "actions", - HasRepoOverride: true, - Remotes: func() (ghContext.Remotes, error) { - remote := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "origin", - }, - Repo: ghrepo.New("owner", "repo"), - } - remote2 := &ghContext.Remote{ - Remote: &git.Remote{ - Name: "upstream", - }, - Repo: ghrepo.New("owner", "repo"), - } - - return ghContext.Remotes{ - remote, - remote2, - }, nil - }, - }, - wantApp: "actions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reg := &httpmock.Registry{} - - if tt.wantApp != "" { - reg.Register(httpmock.REST("GET", fmt.Sprintf("repos/owner/repo/%s/secrets/public-key", tt.wantApp)), - httpmock.JSONResponse(PubKey{ID: "123", Key: "CDjXqf7AJBXWhMczcy+Fs7JlACEptgceysutztHaFQI="})) - - reg.Register(httpmock.REST("PUT", fmt.Sprintf("repos/owner/repo/%s/secrets/cool_secret", tt.wantApp)), - httpmock.StatusStringResponse(201, `{}`)) - } - - ios, _, _, _ := iostreams.Test() - - opts := &SetOptions{ - HttpClient: func() (*http.Client, error) { - return &http.Client{Transport: reg}, nil - }, - Config: func() (gh.Config, error) { return config.NewBlankConfig(), nil }, - BaseRepo: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("owner/repo") - }, - IO: ios, - SecretName: "cool_secret", - Body: "a secret", - RandomOverride: fakeRandom, - Application: tt.opts.Application, - HasRepoOverride: tt.opts.HasRepoOverride, - Remotes: tt.opts.Remotes, - } - - err := setRun(opts) - if tt.wantErr { - assert.EqualError(t, err, tt.errMsg) - } else { - assert.NoError(t, err) - } - - reg.Verify(t) - - if tt.wantApp != "" && !tt.wantErr { - data, err := io.ReadAll(reg.Requests[1].Body) - assert.NoError(t, err) - - var payload SecretPayload - err = json.Unmarshal(data, &payload) - assert.NoError(t, err) - assert.Equal(t, payload.KeyID, "123") - assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=") - } - }) - } -} - func fakeRandom() io.Reader { return bytes.NewReader(bytes.Repeat([]byte{5}, 32)) } diff --git a/pkg/cmd/secret/shared/base_repo.go b/pkg/cmd/secret/shared/base_repo.go index db2dfc8f1..3420f6da1 100644 --- a/pkg/cmd/secret/shared/base_repo.go +++ b/pkg/cmd/secret/shared/base_repo.go @@ -1,47 +1,80 @@ package shared import ( - "fmt" + "errors" ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" ) -func ValidateHasOnlyOneRemote(hasRepoOverride bool, remotes func() (ghContext.Remotes, error)) error { - if !hasRepoOverride && remotes != nil { +type MultipleRemotesError struct { + Remotes ghContext.Remotes +} + +func (e MultipleRemotesError) Error() string { + return "multiple remotes detected. please specify which repo to use by providing the -R or --repo argument" +} + +type baseRepoFn func() (ghrepo.Interface, error) +type remotesFn func() (ghContext.Remotes, error) + +func PromptWhenMultipleRemotesBaseRepoFunc(baseRepoFn baseRepoFn, prompter prompter.Prompter) baseRepoFn { + return func() (ghrepo.Interface, error) { + baseRepo, err := baseRepoFn() + if err != nil { + var multipleRemotesError MultipleRemotesError + if !errors.As(err, &multipleRemotesError) { + return nil, err + } + + // prompt for the base repo + baseRepo, err = promptForRepo(baseRepo, multipleRemotesError.Remotes, prompter) + if err != nil { + return nil, err + } + } + + return baseRepo, nil + } +} + +// RequireNoAmbiguityBaseRepoFunc returns a function to resolve the base repo, ensuring that +// there was only one remote. +func RequireNoAmbiguityBaseRepoFunc(baseRepo baseRepoFn, remotes remotesFn) baseRepoFn { + return func() (ghrepo.Interface, error) { + // TODO: Is this really correct? Some remotes may not be in the same network. We probably need to resolve the + // network rather than looking at the remotes? remotes, err := remotes() if err != nil { - return err + return nil, err } if remotes.Len() > 1 { - return fmt.Errorf("multiple remotes detected %v. please specify which repo to use by providing the -R or --repo argument", remotes) + return nil, MultipleRemotesError{Remotes: remotes} } - } - return nil + return baseRepo() + } } -func PromptForRepo(baseRepo ghrepo.Interface, remotes func() (ghContext.Remotes, error), survey prompter.Prompter) (ghrepo.Interface, error) { +func promptForRepo(baseRepo ghrepo.Interface, remotes ghContext.Remotes, prompter prompter.Prompter) (ghrepo.Interface, error) { var defaultRepo string var remoteArray []string - if remotes, _ := remotes(); remotes != nil { - if defaultRemote, _ := remotes.ResolvedRemote(); defaultRemote != nil { - // this is a remote explicitly chosen via `repo set-default` - defaultRepo = ghrepo.FullName(defaultRemote) - } else if len(remotes) > 0 { - // as a fallback, just pick the first remote - defaultRepo = ghrepo.FullName(remotes[0]) - } - - for _, remote := range remotes { - remoteArray = append(remoteArray, ghrepo.FullName(remote)) - } + if defaultRemote, _ := remotes.ResolvedRemote(); defaultRemote != nil { + // this is a remote explicitly chosen via `repo set-default` + defaultRepo = ghrepo.FullName(defaultRemote) + } else if len(remotes) > 0 { + // as a fallback, just pick the first remote + defaultRepo = ghrepo.FullName(remotes[0]) } - baseRepoInput, errInput := survey.Select("Select a base repo", defaultRepo, remoteArray) + for _, remote := range remotes { + remoteArray = append(remoteArray, ghrepo.FullName(remote)) + } + + baseRepoInput, errInput := prompter.Select("Select a base repo", defaultRepo, remoteArray) if errInput != nil { return baseRepo, errInput } diff --git a/pkg/cmd/secret/shared/base_repo_test.go b/pkg/cmd/secret/shared/base_repo_test.go new file mode 100644 index 000000000..64b2a3d55 --- /dev/null +++ b/pkg/cmd/secret/shared/base_repo_test.go @@ -0,0 +1,239 @@ +package shared_test + +import ( + "errors" + "testing" + + ghContext "github.com/cli/cli/v2/context" + "github.com/cli/cli/v2/pkg/cmd/secret/shared" + "github.com/stretchr/testify/require" + + "github.com/cli/cli/v2/git" + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" +) + +func TestRequireNoAmbiguityBaseRepoFunc(t *testing.T) { + t.Parallel() + + t.Run("succeeds when there is only one remote", func(t *testing.T) { + t.Parallel() + + // Given there is only one remote + baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(baseRepoStubFn, oneRemoteStubFn) + + // When fetching the base repo + baseRepo, err := baseRepoFn() + + // It succeeds and returns the inner base repo + require.NoError(t, err) + require.True(t, ghrepo.IsSame(ghrepo.New("owner", "repo"), baseRepo)) + }) + + t.Run("returns specific error when there are multiple remotes", func(t *testing.T) { + t.Parallel() + + // Given there are multiple remotes + baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(baseRepoStubFn, twoRemotesStubFn) + + // When fetching the base repo + _, err := baseRepoFn() + + // It succeeds and returns the inner base repo + var multipleRemotesError shared.MultipleRemotesError + require.ErrorAs(t, err, &multipleRemotesError) + require.Equal(t, ghContext.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "fork"), + }, + { + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + }, + }, multipleRemotesError.Remotes) + }) + + t.Run("when the remote fetching function fails, it returns the error", func(t *testing.T) { + t.Parallel() + + // Given the remote fetching function fails + baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(baseRepoStubFn, errRemoteStubFn) + + // When fetching the base repo + _, err := baseRepoFn() + + // It returns the error + require.Equal(t, errors.New("test remote error"), err) + }) + + t.Run("when the wrapped base repo function fails, it returns the error", func(t *testing.T) { + t.Parallel() + + // Given the wrapped base repo function fails + baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(errBaseRepoStubFn, oneRemoteStubFn) + + // When fetching the base repo + _, err := baseRepoFn() + + // It returns the error + require.Equal(t, errors.New("test base repo error"), err) + }) +} + +func TestPromptWhenMultipleRemotesBaseRepoFunc(t *testing.T) { + t.Parallel() + + t.Run("when there is no error from wrapped base repo func, then it succeeds without prompting", func(t *testing.T) { + t.Parallel() + + // Given the base repo function succeeds + baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(baseRepoStubFn, nil) + + // When fetching the base repo + baseRepo, err := baseRepoFn() + + // It succeeds and returns the inner base repo + require.NoError(t, err) + require.True(t, ghrepo.IsSame(ghrepo.New("owner", "repo"), baseRepo)) + }) + + t.Run("when the wrapped base repo func returns a specific error, then the prompter is used for disambiguation, with the resolved remote as the default", func(t *testing.T) { + t.Parallel() + + pm := prompter.NewMockPrompter(t) + pm.RegisterSelect( + "Select a base repo", + []string{"owner/fork", "owner/repo"}, + func(_, def string, opts []string) (int, error) { + t.Helper() + require.Equal(t, "owner/repo", def) + return prompter.IndexFor(opts, "owner/repo") + }, + ) + + // Given the wrapped base repo func returns a specific error + baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(errMultipleRemotesStubFn, pm) + + // When fetching the base repo + baseRepo, err := baseRepoFn() + + // It uses the prompter for disambiguation + require.NoError(t, err) + require.True(t, ghrepo.IsSame(ghrepo.New("owner", "repo"), baseRepo)) + }) + + t.Run("when the prompter returns an error, then it is returned", func(t *testing.T) { + t.Parallel() + + // Given the prompter returns an error + pm := prompter.NewMockPrompter(t) + pm.RegisterSelect( + "Select a base repo", + []string{"owner/fork", "owner/repo"}, + func(_, _ string, opts []string) (int, error) { + return 0, errors.New("test prompt error") + }, + ) + + // Given the wrapped base repo func returns a specific error + baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(errMultipleRemotesStubFn, pm) + + // When fetching the base repo + _, err := baseRepoFn() + + // It returns the error + require.Equal(t, errors.New("test prompt error"), err) + }) + + t.Run("when the wrapped base repo func returns a non-specific error, then it is returned", func(t *testing.T) { + t.Parallel() + + // Given the wrapped base repo func returns a non-specific error + baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(errBaseRepoStubFn, nil) + + // When fetching the base repo + _, err := baseRepoFn() + + // It returns the error + require.Equal(t, errors.New("test base repo error"), err) + }) +} + +func TestMultipleRemotesErrorMessage(t *testing.T) { + err := shared.MultipleRemotesError{} + require.EqualError(t, err, "multiple remotes detected. please specify which repo to use by providing the -R or --repo argument") +} + +func errMultipleRemotesStubFn() (ghrepo.Interface, error) { + remote1 := &ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "fork"), + } + + remote2 := &ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("owner", "repo"), + } + + return nil, shared.MultipleRemotesError{ + Remotes: ghContext.Remotes{ + remote1, + remote2, + }, + } +} + +func baseRepoStubFn() (ghrepo.Interface, error) { + return ghrepo.New("owner", "repo"), nil +} + +func oneRemoteStubFn() (ghContext.Remotes, error) { + remote := &ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "repo"), + } + + return ghContext.Remotes{ + remote, + }, nil +} + +func twoRemotesStubFn() (ghContext.Remotes, error) { + remote1 := &ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("owner", "fork"), + } + + remote2 := &ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + } + return ghContext.Remotes{ + remote1, + remote2, + }, nil +} + +func errRemoteStubFn() (ghContext.Remotes, error) { + return nil, errors.New("test remote error") +} + +func errBaseRepoStubFn() (ghrepo.Interface, error) { + return nil, errors.New("test base repo error") +} diff --git a/pkg/cmdutil/flags.go b/pkg/cmdutil/flags.go index fadbc5fbe..c0064099c 100644 --- a/pkg/cmdutil/flags.go +++ b/pkg/cmdutil/flags.go @@ -180,11 +180,3 @@ func isIncluded(value string, opts []string) bool { } return false } - -func CountSetFlags(flags *pflag.FlagSet) int { - count := 0 - flags.Visit(func(f *pflag.Flag) { - count++ - }) - return count -}