diff --git a/pkg/cmd/gist/edit/edit.go b/pkg/cmd/gist/edit/edit.go index a8fa4d683..27df4a104 100644 --- a/pkg/cmd/gist/edit/edit.go +++ b/pkg/cmd/gist/edit/edit.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "io/ioutil" "net/http" + "path/filepath" "sort" "strings" @@ -28,8 +30,9 @@ type EditOptions struct { Edit func(string, string, string, *iostreams.IOStreams) (string, error) - Selector string - Filename string + Selector string + EditFilename string + AddFilename string } func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Command { @@ -60,7 +63,9 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman return editRun(&opts) }, } - cmd.Flags().StringVarP(&opts.Filename, "filename", "f", "", "Select a file to edit") + + cmd.Flags().StringVarP(&opts.AddFilename, "add", "a", "", "Add a new file to the gist") + cmd.Flags().StringVarP(&opts.EditFilename, "filename", "f", "", "Select a file to edit") return cmd } @@ -85,6 +90,9 @@ func editRun(opts *EditOptions) error { gist, err := shared.GetGist(client, ghinstance.OverridableDefault(), gistID) if err != nil { + if errors.Is(err, shared.NotFoundErr) { + return fmt.Errorf("gist not found: %s", gistID) + } return err } @@ -97,10 +105,20 @@ func editRun(opts *EditOptions) error { return fmt.Errorf("You do not own this gist.") } + if opts.AddFilename != "" { + files, err := getFilesToAdd(opts.AddFilename) + if err != nil { + return err + } + + gist.Files = files + return updateGist(apiClient, ghinstance.OverridableDefault(), gist) + } + filesToUpdate := map[string]string{} for { - filename := opts.Filename + filename := opts.EditFilename candidates := []string{} for filename := range gist.Files { candidates = append(candidates, filename) @@ -222,3 +240,22 @@ func updateGist(apiClient *api.Client, hostname string, gist *shared.Gist) error return nil } + +func getFilesToAdd(file string) (map[string]*shared.GistFile, error) { + content, err := ioutil.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", file, err) + } + + if len(content) == 0 { + return nil, errors.New("file contents cannot be empty") + } + + filename := filepath.Base(file) + return map[string]*shared.GistFile{ + filename: { + Filename: filename, + Content: string(content), + }, + }, nil +} diff --git a/pkg/cmd/gist/edit/edit_test.go b/pkg/cmd/gist/edit/edit_test.go index b9b711a38..73a53e523 100644 --- a/pkg/cmd/gist/edit/edit_test.go +++ b/pkg/cmd/gist/edit/edit_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "io/ioutil" "net/http" + "path/filepath" "testing" "github.com/cli/cli/internal/config" @@ -15,8 +16,26 @@ import ( "github.com/cli/cli/pkg/prompt" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func Test_getFilesToAdd(t *testing.T) { + fileToAdd := filepath.Join(t.TempDir(), "gist-test.txt") + err := ioutil.WriteFile(fileToAdd, []byte("hello"), 0600) + require.NoError(t, err) + + gf, err := getFilesToAdd(fileToAdd) + require.NoError(t, err) + + filename := filepath.Base(fileToAdd) + assert.Equal(t, map[string]*shared.GistFile{ + filename: { + Filename: filename, + Content: "hello", + }, + }, gf) +} + func TestNewCmdEdit(t *testing.T) { tests := []struct { name string @@ -34,8 +53,16 @@ func TestNewCmdEdit(t *testing.T) { name: "filename", cli: "123 --filename cool.md", wants: EditOptions{ - Selector: "123", - Filename: "cool.md", + Selector: "123", + EditFilename: "cool.md", + }, + }, + { + name: "add", + cli: "123 --add cool.md", + wants: EditOptions{ + Selector: "123", + AddFilename: "cool.md", }, }, } @@ -60,13 +87,18 @@ func TestNewCmdEdit(t *testing.T) { _, err = cmd.ExecuteC() assert.NoError(t, err) - assert.Equal(t, tt.wants.Filename, gotOpts.Filename) + assert.Equal(t, tt.wants.EditFilename, gotOpts.EditFilename) + assert.Equal(t, tt.wants.AddFilename, gotOpts.AddFilename) assert.Equal(t, tt.wants.Selector, gotOpts.Selector) }) } } func Test_editRun(t *testing.T) { + fileToAdd := filepath.Join(t.TempDir(), "gist-test.txt") + err := ioutil.WriteFile(fileToAdd, []byte("hello"), 0600) + require.NoError(t, err) + tests := []struct { name string opts *EditOptions @@ -74,13 +106,12 @@ func Test_editRun(t *testing.T) { httpStubs func(*httpmock.Registry) askStubs func(*prompt.AskStubber) nontty bool - wantErr bool - wantStderr string + wantErr string wantParams map[string]interface{} }{ { name: "no such gist", - wantErr: true, + wantErr: "gist not found: 1234", }, { name: "one file", @@ -163,7 +194,7 @@ func Test_editRun(t *testing.T) { as.StubOne("unix.md") as.StubOne("Cancel") }, - wantErr: true, + wantErr: "SilentError", gist: &shared.Gist{ ID: "1234", Files: map[string]*shared.GistFile{ @@ -208,8 +239,28 @@ func Test_editRun(t *testing.T) { }, Owner: &shared.GistOwner{Login: "octocat2"}, }, - wantErr: true, - wantStderr: "You do not own this gist.", + wantErr: "You do not own this gist.", + }, + { + name: "add file to existing gist", + gist: &shared.Gist{ + ID: "1234", + Files: map[string]*shared.GistFile{ + "sample.txt": { + Filename: "sample.txt", + Content: "bwhiizzzbwhuiiizzzz", + Type: "text/plain", + }, + }, + Owner: &shared.GistOwner{Login: "octocat"}, + }, + httpStubs: func(reg *httpmock.Registry) { + reg.Register(httpmock.REST("POST", "gists/1234"), + httpmock.StatusStringResponse(201, "{}")) + }, + opts: &EditOptions{ + AddFilename: fileToAdd, + }, }, } @@ -246,7 +297,7 @@ func Test_editRun(t *testing.T) { tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } - io, _, _, _ := iostreams.Test() + io, _, stdout, stderr := iostreams.Test() io.SetStdoutTTY(!tt.nontty) io.SetStdinTTY(!tt.nontty) tt.opts.IO = io @@ -260,11 +311,8 @@ func Test_editRun(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := editRun(tt.opts) reg.Verify(t) - if tt.wantErr { - assert.Error(t, err) - if tt.wantStderr != "" { - assert.EqualError(t, err, tt.wantStderr) - } + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) return } assert.NoError(t, err) @@ -278,6 +326,9 @@ func Test_editRun(t *testing.T) { } assert.Equal(t, tt.wantParams, reqBody) } + + assert.Equal(t, "", stdout.String()) + assert.Equal(t, "", stderr.String()) }) } } diff --git a/pkg/cmd/gist/shared/shared.go b/pkg/cmd/gist/shared/shared.go index 04e63ce86..04fe2c33b 100644 --- a/pkg/cmd/gist/shared/shared.go +++ b/pkg/cmd/gist/shared/shared.go @@ -1,6 +1,7 @@ package shared import ( + "errors" "fmt" "net/http" "net/url" @@ -31,6 +32,8 @@ type Gist struct { Owner *GistOwner `json:"owner,omitempty"` } +var NotFoundErr = errors.New("not found") + func GetGist(client *http.Client, hostname, gistID string) (*Gist, error) { gist := Gist{} path := fmt.Sprintf("gists/%s", gistID) @@ -38,6 +41,10 @@ func GetGist(client *http.Client, hostname, gistID string) (*Gist, error) { apiClient := api.NewClientFromHTTP(client) err := apiClient.REST(hostname, "GET", path, nil, &gist) if err != nil { + var httpErr api.HTTPError + if errors.As(err, &httpErr) && httpErr.StatusCode == 404 { + return nil, NotFoundErr + } return nil, err }