Use smarter base repo funcs for secret commands

This commit is contained in:
William Martin 2024-12-18 13:36:53 +01:00
parent 73244c010e
commit 870da79886
9 changed files with 647 additions and 500 deletions

View file

@ -6,7 +6,6 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmd/secret/shared"
@ -20,15 +19,12 @@ type DeleteOptions struct {
IO *iostreams.IOStreams
Config func() (gh.Config, error)
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (ghContext.Remotes, error)
SecretName string
OrgName string
EnvName string
UserSecrets bool
Application string
HasRepoOverride bool
}
func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Command {
@ -36,7 +32,6 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
IO: f.IOStreams,
Config: f.Config,
HttpClient: f.HttpClient,
Remotes: f.Remotes,
}
cmd := &cobra.Command{
@ -51,15 +46,24 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co
`),
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
// If the user specified a repo directly, then we're using the OverrideBaseRepoFunc set by EnableRepoOverride
// So there's no reason to use the specialised BaseRepoFunc that requires remote disambiguation.
opts.BaseRepo = f.BaseRepo
if !cmd.Flags().Changed("repo") {
// If they haven't specified a repo directly, then we will wrap the BaseRepoFunc in one that error if
// there might be multiple valid remotes.
opts.BaseRepo = shared.RequireNoAmbiguityBaseRepoFunc(opts.BaseRepo, f.Remotes)
// But if we are able to prompt, then we will wrap that up in a BaseRepoFunc that can prompt the user to
// resolve the ambiguity.
if opts.IO.CanPrompt() {
opts.BaseRepo = shared.PromptWhenMultipleRemotesBaseRepoFunc(opts.BaseRepo, f.Prompter)
}
}
if err := cmdutil.MutuallyExclusive("specify only one of `--org`, `--env`, or `--user`", opts.OrgName != "", opts.EnvName != "", opts.UserSecrets); err != nil {
return err
}
opts.HasRepoOverride = cmd.Flags().Changed("repo")
opts.SecretName = args[0]
if runF != nil {
@ -110,11 +114,6 @@ func removeRun(opts *DeleteOptions) error {
if err != nil {
return err
}
err = shared.ValidateHasOnlyOneRemote(opts.HasRepoOverride, opts.Remotes)
if err != nil {
return err
}
}
cfg, err := opts.Config()

View file

@ -2,6 +2,7 @@ package delete
import (
"bytes"
"io"
"net/http"
"testing"
@ -10,6 +11,8 @@ import (
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmd/secret/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/cli/cli/v2/pkg/iostreams"
@ -122,6 +125,108 @@ func TestNewCmdDelete(t *testing.T) {
}
}
func TestNewCmdDeleteBaseRepoFuncs(t *testing.T) {
remotes := ghContext.Remotes{
&ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "fork"),
},
&ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
},
}
tests := []struct {
name string
args string
prompterStubs func(*prompter.MockPrompter)
wantRepo ghrepo.Interface
wantErr error
}{
{
name: "when there is a repo flag provided, the factory base repo func is used",
args: "SECRET_NAME --repo owner/repo",
wantRepo: ghrepo.New("owner", "repo"),
},
{
name: "when there is no repo flag provided, and no prompting, the base func requiring no ambiguity is used",
args: "SECRET_NAME",
wantErr: shared.MultipleRemotesError{
Remotes: remotes,
},
},
{
name: "when there is no repo flag provided, and can prompt, the base func resolving ambiguity is used",
args: "SECRET_NAME",
prompterStubs: func(pm *prompter.MockPrompter) {
pm.RegisterSelect(
"Select a base repo",
[]string{"owner/fork", "owner/repo"},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "owner/fork")
},
)
},
wantRepo: ghrepo.New("owner", "fork"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, _ := iostreams.Test()
var pm *prompter.MockPrompter
if tt.prompterStubs != nil {
ios.SetStdinTTY(true)
ios.SetStdoutTTY(true)
ios.SetStderrTTY(true)
pm = prompter.NewMockPrompter(t)
tt.prompterStubs(pm)
}
f := &cmdutil.Factory{
IOStreams: ios,
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("owner/repo")
},
Prompter: pm,
Remotes: func() (ghContext.Remotes, error) {
return remotes, nil
},
}
argv, err := shlex.Split(tt.args)
assert.NoError(t, err)
var gotOpts *DeleteOptions
cmd := NewCmdDelete(f, func(opts *DeleteOptions) error {
gotOpts = opts
return nil
})
// Require to support --repo flag
cmdutil.EnableRepoOverride(cmd, f)
cmd.SetArgs(argv)
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
_, err = cmd.ExecuteC()
require.NoError(t, err)
baseRepo, err := gotOpts.BaseRepo()
if tt.wantErr != nil {
require.Equal(t, tt.wantErr, err)
return
}
require.True(t, ghrepo.IsSame(tt.wantRepo, baseRepo))
})
}
}
func Test_removeRun_repo(t *testing.T) {
tests := []struct {
name string
@ -353,119 +458,3 @@ func Test_removeRun_user(t *testing.T) {
reg.Verify(t)
}
func Test_removeRun_remote_validation(t *testing.T) {
tests := []struct {
name string
opts *DeleteOptions
wantPath string
wantErr bool
errMsg string
}{
{
name: "single repo detected",
opts: &DeleteOptions{
Application: "actions",
SecretName: "cool_secret",
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
}, nil
}},
wantPath: "repos/owner/repo/actions/secrets/cool_secret",
},
{
name: "multi repo detected",
opts: &DeleteOptions{
Application: "actions",
SecretName: "cool_secret",
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
remote2,
}, nil
}},
wantErr: true,
errMsg: "multiple remotes detected [origin upstream]. please specify which repo to use by providing the -R or --repo argument",
},
{
name: "multi repo detected - single repo given",
opts: &DeleteOptions{
Application: "actions",
SecretName: "cool_secret",
HasRepoOverride: true,
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
remote2,
}, nil
}},
wantPath: "repos/owner/repo/actions/secrets/cool_secret",
},
}
for _, tt := range tests {
reg := &httpmock.Registry{}
if tt.wantPath != "" {
reg.Register(
httpmock.REST("DELETE", tt.wantPath),
httpmock.StatusStringResponse(204, "No Content"))
}
ios, _, _, _ := iostreams.Test()
tt.opts.IO = ios
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
tt.opts.Config = func() (gh.Config, error) {
return config.NewBlankConfig(), nil
}
tt.opts.BaseRepo = func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("owner/repo")
}
err := removeRun(tt.opts)
if tt.wantErr {
assert.EqualError(t, err, tt.errMsg)
} else {
assert.NoError(t, err)
}
reg.Verify(t)
}
}

View file

@ -9,9 +9,9 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/internal/tableprinter"
"github.com/cli/cli/v2/pkg/cmd/secret/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
@ -24,16 +24,15 @@ type ListOptions struct {
IO *iostreams.IOStreams
Config func() (gh.Config, error)
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (ghContext.Remotes, error)
Now func() time.Time
Exporter cmdutil.Exporter
Prompter prompter.Prompter
Now func() time.Time
Exporter cmdutil.Exporter
OrgName string
EnvName string
UserSecrets bool
Application string
HasRepoOverride bool
}
var secretFields = []string{
@ -51,8 +50,8 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman
IO: f.IOStreams,
Config: f.Config,
HttpClient: f.HttpClient,
Remotes: f.Remotes,
Now: time.Now,
Prompter: f.Prompter,
}
cmd := &cobra.Command{
@ -68,15 +67,24 @@ func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Comman
Aliases: []string{"ls"},
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
// If the user specified a repo directly, then we're using the OverrideBaseRepoFunc set by EnableRepoOverride
// So there's no reason to use the specialised BaseRepoFunc that requires remote disambiguation.
opts.BaseRepo = f.BaseRepo
if !cmd.Flags().Changed("repo") {
// If they haven't specified a repo directly, then we will wrap the BaseRepoFunc in one that error if
// there might be multiple valid remotes.
opts.BaseRepo = shared.RequireNoAmbiguityBaseRepoFunc(opts.BaseRepo, f.Remotes)
// But if we are able to prompt, then we will wrap that up in a BaseRepoFunc that can prompt the user to
// resolve the ambiguity.
if opts.IO.CanPrompt() {
opts.BaseRepo = shared.PromptWhenMultipleRemotesBaseRepoFunc(opts.BaseRepo, f.Prompter)
}
}
if err := cmdutil.MutuallyExclusive("specify only one of `--org`, `--env`, or `--user`", opts.OrgName != "", opts.EnvName != "", opts.UserSecrets); err != nil {
return err
}
opts.HasRepoOverride = cmd.Flags().Changed("repo")
if runF != nil {
return runF(opts)
}
@ -108,11 +116,6 @@ func listRun(opts *ListOptions) error {
if err != nil {
return err
}
err = shared.ValidateHasOnlyOneRemote(opts.HasRepoOverride, opts.Remotes)
if err != nil {
return err
}
}
secretEntity, err := shared.GetSecretEntity(orgName, envName, opts.UserSecrets)

View file

@ -3,6 +3,7 @@ package list
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"strings"
@ -14,6 +15,7 @@ import (
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmd/secret/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
@ -103,6 +105,108 @@ func Test_NewCmdList(t *testing.T) {
}
}
func TestNewCmdListBaseRepoFuncs(t *testing.T) {
remotes := ghContext.Remotes{
&ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "fork"),
},
&ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
},
}
tests := []struct {
name string
args string
prompterStubs func(*prompter.MockPrompter)
wantRepo ghrepo.Interface
wantErr error
}{
{
name: "when there is a repo flag provided, the factory base repo func is used",
args: "--repo owner/repo",
wantRepo: ghrepo.New("owner", "repo"),
},
{
name: "when there is no repo flag provided, and no prompting, the base func requiring no ambiguity is used",
args: "",
wantErr: shared.MultipleRemotesError{
Remotes: remotes,
},
},
{
name: "when there is no repo flag provided, and can prompt, the base func resolving ambiguity is used",
args: "",
prompterStubs: func(pm *prompter.MockPrompter) {
pm.RegisterSelect(
"Select a base repo",
[]string{"owner/fork", "owner/repo"},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "owner/fork")
},
)
},
wantRepo: ghrepo.New("owner", "fork"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, _ := iostreams.Test()
var pm *prompter.MockPrompter
if tt.prompterStubs != nil {
ios.SetStdinTTY(true)
ios.SetStdoutTTY(true)
ios.SetStderrTTY(true)
pm = prompter.NewMockPrompter(t)
tt.prompterStubs(pm)
}
f := &cmdutil.Factory{
IOStreams: ios,
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("owner/repo")
},
Prompter: pm,
Remotes: func() (ghContext.Remotes, error) {
return remotes, nil
},
}
argv, err := shlex.Split(tt.args)
assert.NoError(t, err)
var gotOpts *ListOptions
cmd := NewCmdList(f, func(opts *ListOptions) error {
gotOpts = opts
return nil
})
// Require to support --repo flag
cmdutil.EnableRepoOverride(cmd, f)
cmd.SetArgs(argv)
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
_, err = cmd.ExecuteC()
require.NoError(t, err)
baseRepo, err := gotOpts.BaseRepo()
if tt.wantErr != nil {
require.Equal(t, tt.wantErr, err)
return
}
require.True(t, ghrepo.IsSame(tt.wantRepo, baseRepo))
})
}
}
func Test_listRun(t *testing.T) {
tests := []struct {
name string
@ -444,173 +548,6 @@ func Test_listRun(t *testing.T) {
}
}
func Test_listRunRemoteValidation(t *testing.T) {
tests := []struct {
name string
tty bool
json bool
opts *ListOptions
wantOut []string
wantErr bool
errMsg string
}{
{
name: "single repo detected",
tty: false,
opts: &ListOptions{
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
}, nil
},
},
wantOut: []string{
"SECRET_ONE\t1988-10-11T00:00:00Z",
"SECRET_TWO\t2020-12-04T00:00:00Z",
"SECRET_THREE\t1975-11-30T00:00:00Z",
},
},
{
name: "multi repo detected",
tty: false,
opts: &ListOptions{
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
remote2,
}, nil
},
},
wantOut: []string{},
wantErr: true,
errMsg: "multiple remotes detected [origin upstream]. please specify which repo to use by providing the -R or --repo argument",
},
{
name: "multi repo detected - single repo given",
tty: false,
opts: &ListOptions{
HasRepoOverride: true,
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
remote2,
}, nil
},
},
wantOut: []string{
"SECRET_ONE\t1988-10-11T00:00:00Z",
"SECRET_TWO\t2020-12-04T00:00:00Z",
"SECRET_THREE\t1975-11-30T00:00:00Z",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
reg.Verify(t)
path := "repos/owner/repo/actions/secrets"
t0, _ := time.Parse("2006-01-02", "1988-10-11")
t1, _ := time.Parse("2006-01-02", "2020-12-04")
t2, _ := time.Parse("2006-01-02", "1975-11-30")
payload := struct {
Secrets []Secret
}{
Secrets: []Secret{
{
Name: "SECRET_ONE",
UpdatedAt: t0,
},
{
Name: "SECRET_TWO",
UpdatedAt: t1,
},
{
Name: "SECRET_THREE",
UpdatedAt: t2,
},
},
}
reg.Register(httpmock.REST("GET", path), httpmock.JSONResponse(payload))
ios, _, stdout, _ := iostreams.Test()
ios.SetStdoutTTY(tt.tty)
tt.opts.IO = ios
tt.opts.BaseRepo = func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("owner/repo")
}
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
tt.opts.Config = func() (gh.Config, error) {
return config.NewBlankConfig(), nil
}
tt.opts.Now = func() time.Time {
t, _ := time.Parse(time.RFC822, "15 Mar 23 00:00 UTC")
return t
}
if tt.json {
exporter := cmdutil.NewJSONExporter()
exporter.SetFields(secretFields)
tt.opts.Exporter = exporter
}
err := listRun(tt.opts)
if tt.wantErr {
assert.EqualError(t, err, tt.errMsg)
return
}
assert.NoError(t, err)
if len(tt.wantOut) > 1 {
expected := fmt.Sprintf("%s\n", strings.Join(tt.wantOut, "\n"))
assert.Equal(t, expected, stdout.String())
}
})
}
}
// Test_listRun_populatesNumSelectedReposIfRequired asserts that NumSelectedRepos
// field is populated **only** when it's going to be presented in the output. Since
// populating this field costs further API requests (one per secret), it's important

View file

@ -13,7 +13,6 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmd/secret/shared"
@ -30,7 +29,6 @@ type SetOptions struct {
IO *iostreams.IOStreams
Config func() (gh.Config, error)
BaseRepo func() (ghrepo.Interface, error)
Remotes func() (ghContext.Remotes, error)
Prompter prompter.Prompter
RandomOverride func() io.Reader
@ -45,9 +43,6 @@ type SetOptions struct {
RepositoryNames []string
EnvFile string
Application string
HasRepoOverride bool
CanPrompt bool
}
func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command {
@ -55,7 +50,6 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command
IO: f.IOStreams,
Config: f.Config,
HttpClient: f.HttpClient,
Remotes: f.Remotes,
Prompter: f.Prompter,
}
@ -110,8 +104,19 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// support `-R, --repo` override
// If the user specified a repo directly, then we're using the OverrideBaseRepoFunc set by EnableRepoOverride
// So there's no reason to use the specialised BaseRepoFunc that requires remote disambiguation.
opts.BaseRepo = f.BaseRepo
if !cmd.Flags().Changed("repo") {
// If they haven't specified a repo directly, then we will wrap the BaseRepoFunc in one that error if
// there might be multiple valid remotes.
opts.BaseRepo = shared.RequireNoAmbiguityBaseRepoFunc(opts.BaseRepo, f.Remotes)
// But if we are able to prompt, then we will wrap that up in a BaseRepoFunc that can prompt the user to
// resolve the ambiguity.
if opts.IO.CanPrompt() {
opts.BaseRepo = shared.PromptWhenMultipleRemotesBaseRepoFunc(opts.BaseRepo, f.Prompter)
}
}
if err := cmdutil.MutuallyExclusive("specify only one of `--org`, `--env`, or `--user`", opts.OrgName != "", opts.EnvName != "", opts.UserSecrets); err != nil {
return err
@ -151,9 +156,6 @@ func NewCmdSet(f *cmdutil.Factory, runF func(*SetOptions) error) *cobra.Command
}
}
opts.HasRepoOverride = cmd.Flags().Changed("repo")
opts.CanPrompt = opts.IO.CanPrompt()
if runF != nil {
return runF(opts)
}
@ -198,19 +200,6 @@ func setRun(opts *SetOptions) error {
return err
}
if err = shared.ValidateHasOnlyOneRemote(opts.HasRepoOverride, opts.Remotes); err != nil {
if !opts.CanPrompt {
return err
}
selectedRepo, errSelectingRepo := shared.PromptForRepo(baseRepo, opts.Remotes, opts.Prompter)
if errSelectingRepo != nil {
return errSelectingRepo
}
baseRepo = selectedRepo
}
host = baseRepo.RepoHost()
} else {
cfg, err := opts.Config()

View file

@ -22,6 +22,7 @@ import (
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewCmdSet(t *testing.T) {
@ -223,6 +224,108 @@ func TestNewCmdSet(t *testing.T) {
}
}
func TestNewCmdSetBaseRepoFuncs(t *testing.T) {
remotes := ghContext.Remotes{
&ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "fork"),
},
&ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
},
}
tests := []struct {
name string
args string
prompterStubs func(*prompter.MockPrompter)
wantRepo ghrepo.Interface
wantErr error
}{
{
name: "when there is a repo flag provided, the factory base repo func is used",
args: "SECRET_NAME --repo owner/repo",
wantRepo: ghrepo.New("owner", "repo"),
},
{
name: "when there is no repo flag provided, and no prompting, the base func requiring no ambiguity is used",
args: "SECRET_NAME",
wantErr: shared.MultipleRemotesError{
Remotes: remotes,
},
},
{
name: "when there is no repo flag provided, and can prompt, the base func resolving ambiguity is used",
args: "SECRET_NAME",
prompterStubs: func(pm *prompter.MockPrompter) {
pm.RegisterSelect(
"Select a base repo",
[]string{"owner/fork", "owner/repo"},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "owner/fork")
},
)
},
wantRepo: ghrepo.New("owner", "fork"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, _ := iostreams.Test()
var pm *prompter.MockPrompter
if tt.prompterStubs != nil {
ios.SetStdinTTY(true)
ios.SetStdoutTTY(true)
ios.SetStderrTTY(true)
pm = prompter.NewMockPrompter(t)
tt.prompterStubs(pm)
}
f := &cmdutil.Factory{
IOStreams: ios,
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("owner/repo")
},
Prompter: pm,
Remotes: func() (ghContext.Remotes, error) {
return remotes, nil
},
}
argv, err := shlex.Split(tt.args)
assert.NoError(t, err)
var gotOpts *SetOptions
cmd := NewCmdSet(f, func(opts *SetOptions) error {
gotOpts = opts
return nil
})
// Require to support --repo flag
cmdutil.EnableRepoOverride(cmd, f)
cmd.SetArgs(argv)
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
_, err = cmd.ExecuteC()
require.NoError(t, err)
baseRepo, err := gotOpts.BaseRepo()
if tt.wantErr != nil {
require.Equal(t, tt.wantErr, err)
return
}
require.True(t, ghrepo.IsSame(tt.wantRepo, baseRepo))
})
}
}
func Test_setRun_repo(t *testing.T) {
tests := []struct {
name string
@ -702,143 +805,6 @@ func Test_getSecretsFromOptions(t *testing.T) {
}
}
func Test_setRun_remote_validation(t *testing.T) {
tests := []struct {
name string
opts *SetOptions
wantApp string
wantErr bool
errMsg string
}{
{
name: "single repo detected",
opts: &SetOptions{
Application: "actions",
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
}, nil
},
},
wantApp: "actions",
},
{
name: "multi repo detected",
opts: &SetOptions{
Application: "actions",
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
remote2,
}, nil
},
},
wantErr: true,
errMsg: "multiple remotes detected [origin upstream]. please specify which repo to use by providing the -R or --repo argument",
},
{
name: "multi repo detected - single repo given",
opts: &SetOptions{
Application: "actions",
HasRepoOverride: true,
Remotes: func() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
remote2,
}, nil
},
},
wantApp: "actions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
if tt.wantApp != "" {
reg.Register(httpmock.REST("GET", fmt.Sprintf("repos/owner/repo/%s/secrets/public-key", tt.wantApp)),
httpmock.JSONResponse(PubKey{ID: "123", Key: "CDjXqf7AJBXWhMczcy+Fs7JlACEptgceysutztHaFQI="}))
reg.Register(httpmock.REST("PUT", fmt.Sprintf("repos/owner/repo/%s/secrets/cool_secret", tt.wantApp)),
httpmock.StatusStringResponse(201, `{}`))
}
ios, _, _, _ := iostreams.Test()
opts := &SetOptions{
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
Config: func() (gh.Config, error) { return config.NewBlankConfig(), nil },
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.FromFullName("owner/repo")
},
IO: ios,
SecretName: "cool_secret",
Body: "a secret",
RandomOverride: fakeRandom,
Application: tt.opts.Application,
HasRepoOverride: tt.opts.HasRepoOverride,
Remotes: tt.opts.Remotes,
}
err := setRun(opts)
if tt.wantErr {
assert.EqualError(t, err, tt.errMsg)
} else {
assert.NoError(t, err)
}
reg.Verify(t)
if tt.wantApp != "" && !tt.wantErr {
data, err := io.ReadAll(reg.Requests[1].Body)
assert.NoError(t, err)
var payload SecretPayload
err = json.Unmarshal(data, &payload)
assert.NoError(t, err)
assert.Equal(t, payload.KeyID, "123")
assert.Equal(t, payload.EncryptedValue, "UKYUCbHd0DJemxa3AOcZ6XcsBwALG9d4bpB8ZT0gSV39vl3BHiGSgj8zJapDxgB2BwqNqRhpjC4=")
}
})
}
}
func fakeRandom() io.Reader {
return bytes.NewReader(bytes.Repeat([]byte{5}, 32))
}

View file

@ -1,47 +1,80 @@
package shared
import (
"fmt"
"errors"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
)
func ValidateHasOnlyOneRemote(hasRepoOverride bool, remotes func() (ghContext.Remotes, error)) error {
if !hasRepoOverride && remotes != nil {
type MultipleRemotesError struct {
Remotes ghContext.Remotes
}
func (e MultipleRemotesError) Error() string {
return "multiple remotes detected. please specify which repo to use by providing the -R or --repo argument"
}
type baseRepoFn func() (ghrepo.Interface, error)
type remotesFn func() (ghContext.Remotes, error)
func PromptWhenMultipleRemotesBaseRepoFunc(baseRepoFn baseRepoFn, prompter prompter.Prompter) baseRepoFn {
return func() (ghrepo.Interface, error) {
baseRepo, err := baseRepoFn()
if err != nil {
var multipleRemotesError MultipleRemotesError
if !errors.As(err, &multipleRemotesError) {
return nil, err
}
// prompt for the base repo
baseRepo, err = promptForRepo(baseRepo, multipleRemotesError.Remotes, prompter)
if err != nil {
return nil, err
}
}
return baseRepo, nil
}
}
// RequireNoAmbiguityBaseRepoFunc returns a function to resolve the base repo, ensuring that
// there was only one remote.
func RequireNoAmbiguityBaseRepoFunc(baseRepo baseRepoFn, remotes remotesFn) baseRepoFn {
return func() (ghrepo.Interface, error) {
// TODO: Is this really correct? Some remotes may not be in the same network. We probably need to resolve the
// network rather than looking at the remotes?
remotes, err := remotes()
if err != nil {
return err
return nil, err
}
if remotes.Len() > 1 {
return fmt.Errorf("multiple remotes detected %v. please specify which repo to use by providing the -R or --repo argument", remotes)
return nil, MultipleRemotesError{Remotes: remotes}
}
}
return nil
return baseRepo()
}
}
func PromptForRepo(baseRepo ghrepo.Interface, remotes func() (ghContext.Remotes, error), survey prompter.Prompter) (ghrepo.Interface, error) {
func promptForRepo(baseRepo ghrepo.Interface, remotes ghContext.Remotes, prompter prompter.Prompter) (ghrepo.Interface, error) {
var defaultRepo string
var remoteArray []string
if remotes, _ := remotes(); remotes != nil {
if defaultRemote, _ := remotes.ResolvedRemote(); defaultRemote != nil {
// this is a remote explicitly chosen via `repo set-default`
defaultRepo = ghrepo.FullName(defaultRemote)
} else if len(remotes) > 0 {
// as a fallback, just pick the first remote
defaultRepo = ghrepo.FullName(remotes[0])
}
for _, remote := range remotes {
remoteArray = append(remoteArray, ghrepo.FullName(remote))
}
if defaultRemote, _ := remotes.ResolvedRemote(); defaultRemote != nil {
// this is a remote explicitly chosen via `repo set-default`
defaultRepo = ghrepo.FullName(defaultRemote)
} else if len(remotes) > 0 {
// as a fallback, just pick the first remote
defaultRepo = ghrepo.FullName(remotes[0])
}
baseRepoInput, errInput := survey.Select("Select a base repo", defaultRepo, remoteArray)
for _, remote := range remotes {
remoteArray = append(remoteArray, ghrepo.FullName(remote))
}
baseRepoInput, errInput := prompter.Select("Select a base repo", defaultRepo, remoteArray)
if errInput != nil {
return baseRepo, errInput
}

View file

@ -0,0 +1,239 @@
package shared_test
import (
"errors"
"testing"
ghContext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/pkg/cmd/secret/shared"
"github.com/stretchr/testify/require"
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
)
func TestRequireNoAmbiguityBaseRepoFunc(t *testing.T) {
t.Parallel()
t.Run("succeeds when there is only one remote", func(t *testing.T) {
t.Parallel()
// Given there is only one remote
baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(baseRepoStubFn, oneRemoteStubFn)
// When fetching the base repo
baseRepo, err := baseRepoFn()
// It succeeds and returns the inner base repo
require.NoError(t, err)
require.True(t, ghrepo.IsSame(ghrepo.New("owner", "repo"), baseRepo))
})
t.Run("returns specific error when there are multiple remotes", func(t *testing.T) {
t.Parallel()
// Given there are multiple remotes
baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(baseRepoStubFn, twoRemotesStubFn)
// When fetching the base repo
_, err := baseRepoFn()
// It succeeds and returns the inner base repo
var multipleRemotesError shared.MultipleRemotesError
require.ErrorAs(t, err, &multipleRemotesError)
require.Equal(t, ghContext.Remotes{
{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "fork"),
},
{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
},
}, multipleRemotesError.Remotes)
})
t.Run("when the remote fetching function fails, it returns the error", func(t *testing.T) {
t.Parallel()
// Given the remote fetching function fails
baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(baseRepoStubFn, errRemoteStubFn)
// When fetching the base repo
_, err := baseRepoFn()
// It returns the error
require.Equal(t, errors.New("test remote error"), err)
})
t.Run("when the wrapped base repo function fails, it returns the error", func(t *testing.T) {
t.Parallel()
// Given the wrapped base repo function fails
baseRepoFn := shared.RequireNoAmbiguityBaseRepoFunc(errBaseRepoStubFn, oneRemoteStubFn)
// When fetching the base repo
_, err := baseRepoFn()
// It returns the error
require.Equal(t, errors.New("test base repo error"), err)
})
}
func TestPromptWhenMultipleRemotesBaseRepoFunc(t *testing.T) {
t.Parallel()
t.Run("when there is no error from wrapped base repo func, then it succeeds without prompting", func(t *testing.T) {
t.Parallel()
// Given the base repo function succeeds
baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(baseRepoStubFn, nil)
// When fetching the base repo
baseRepo, err := baseRepoFn()
// It succeeds and returns the inner base repo
require.NoError(t, err)
require.True(t, ghrepo.IsSame(ghrepo.New("owner", "repo"), baseRepo))
})
t.Run("when the wrapped base repo func returns a specific error, then the prompter is used for disambiguation, with the resolved remote as the default", func(t *testing.T) {
t.Parallel()
pm := prompter.NewMockPrompter(t)
pm.RegisterSelect(
"Select a base repo",
[]string{"owner/fork", "owner/repo"},
func(_, def string, opts []string) (int, error) {
t.Helper()
require.Equal(t, "owner/repo", def)
return prompter.IndexFor(opts, "owner/repo")
},
)
// Given the wrapped base repo func returns a specific error
baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(errMultipleRemotesStubFn, pm)
// When fetching the base repo
baseRepo, err := baseRepoFn()
// It uses the prompter for disambiguation
require.NoError(t, err)
require.True(t, ghrepo.IsSame(ghrepo.New("owner", "repo"), baseRepo))
})
t.Run("when the prompter returns an error, then it is returned", func(t *testing.T) {
t.Parallel()
// Given the prompter returns an error
pm := prompter.NewMockPrompter(t)
pm.RegisterSelect(
"Select a base repo",
[]string{"owner/fork", "owner/repo"},
func(_, _ string, opts []string) (int, error) {
return 0, errors.New("test prompt error")
},
)
// Given the wrapped base repo func returns a specific error
baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(errMultipleRemotesStubFn, pm)
// When fetching the base repo
_, err := baseRepoFn()
// It returns the error
require.Equal(t, errors.New("test prompt error"), err)
})
t.Run("when the wrapped base repo func returns a non-specific error, then it is returned", func(t *testing.T) {
t.Parallel()
// Given the wrapped base repo func returns a non-specific error
baseRepoFn := shared.PromptWhenMultipleRemotesBaseRepoFunc(errBaseRepoStubFn, nil)
// When fetching the base repo
_, err := baseRepoFn()
// It returns the error
require.Equal(t, errors.New("test base repo error"), err)
})
}
func TestMultipleRemotesErrorMessage(t *testing.T) {
err := shared.MultipleRemotesError{}
require.EqualError(t, err, "multiple remotes detected. please specify which repo to use by providing the -R or --repo argument")
}
func errMultipleRemotesStubFn() (ghrepo.Interface, error) {
remote1 := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "fork"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
Resolved: "base",
},
Repo: ghrepo.New("owner", "repo"),
}
return nil, shared.MultipleRemotesError{
Remotes: ghContext.Remotes{
remote1,
remote2,
},
}
}
func baseRepoStubFn() (ghrepo.Interface, error) {
return ghrepo.New("owner", "repo"), nil
}
func oneRemoteStubFn() (ghContext.Remotes, error) {
remote := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote,
}, nil
}
func twoRemotesStubFn() (ghContext.Remotes, error) {
remote1 := &ghContext.Remote{
Remote: &git.Remote{
Name: "origin",
},
Repo: ghrepo.New("owner", "fork"),
}
remote2 := &ghContext.Remote{
Remote: &git.Remote{
Name: "upstream",
},
Repo: ghrepo.New("owner", "repo"),
}
return ghContext.Remotes{
remote1,
remote2,
}, nil
}
func errRemoteStubFn() (ghContext.Remotes, error) {
return nil, errors.New("test remote error")
}
func errBaseRepoStubFn() (ghrepo.Interface, error) {
return nil, errors.New("test base repo error")
}

View file

@ -180,11 +180,3 @@ func isIncluded(value string, opts []string) bool {
}
return false
}
func CountSetFlags(flags *pflag.FlagSet) int {
count := 0
flags.Visit(func(f *pflag.Flag) {
count++
})
return count
}