diff --git a/pkg/cmd/gist/edit/edit.go b/pkg/cmd/gist/edit/edit.go index 566307b63..2e95b3ca4 100644 --- a/pkg/cmd/gist/edit/edit.go +++ b/pkg/cmd/gist/edit/edit.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/http" - "net/url" "sort" "strings" @@ -69,11 +68,12 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman func editRun(opts *EditOptions) error { gistID := opts.Selector - u, err := url.Parse(opts.Selector) - if err == nil { - if strings.HasPrefix(u.Path, "/") { - gistID = u.Path[1:] + if strings.Contains(gistID, "/") { + id, err := shared.GistIDFromURL(gistID) + if err != nil { + return err } + gistID = id } client, err := opts.HttpClient() diff --git a/pkg/cmd/gist/shared/shared.go b/pkg/cmd/gist/shared/shared.go index 95d8bae31..f285621ab 100644 --- a/pkg/cmd/gist/shared/shared.go +++ b/pkg/cmd/gist/shared/shared.go @@ -3,6 +3,8 @@ package shared import ( "fmt" "net/http" + "net/url" + "strings" "time" "github.com/cli/cli/api" @@ -37,3 +39,20 @@ func GetGist(client *http.Client, hostname, gistID string) (*Gist, error) { return &gist, nil } + +func GistIDFromURL(gistURL string) (string, error) { + u, err := url.Parse(gistURL) + if err == nil && strings.HasPrefix(u.Path, "/") { + split := strings.Split(u.Path, "/") + + if len(split) > 2 { + return split[2], nil + } + + if len(split) == 2 && split[1] != "" { + return split[1], nil + } + } + + return "", fmt.Errorf("Invalid gist URL %s", u) +} diff --git a/pkg/cmd/gist/shared/shared_test.go b/pkg/cmd/gist/shared/shared_test.go new file mode 100644 index 000000000..0f80db690 --- /dev/null +++ b/pkg/cmd/gist/shared/shared_test.go @@ -0,0 +1,52 @@ +package shared + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_GetGistIDFromURL(t *testing.T) { + tests := []struct { + name string + url string + want string + wantErr bool + }{ + { + name: "url", + url: "https://gist.github.com/1234", + want: "1234", + }, + { + name: "url with username", + url: "https://gist.github.com/octocat/1234", + want: "1234", + }, + { + name: "url, specific file", + url: "https://gist.github.com/1234#file-test-md", + want: "1234", + }, + { + name: "invalid url", + url: "https://gist.github.com", + wantErr: true, + want: "Invalid gist URL https://gist.github.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := GistIDFromURL(tt.url) + if tt.wantErr { + assert.Error(t, err) + assert.EqualError(t, err, tt.want) + return + } + assert.NoError(t, err) + + assert.Equal(t, tt.want, id) + }) + } +} diff --git a/pkg/cmd/gist/view/view.go b/pkg/cmd/gist/view/view.go index fecb69adb..aae6b93ba 100644 --- a/pkg/cmd/gist/view/view.go +++ b/pkg/cmd/gist/view/view.go @@ -3,7 +3,6 @@ package view import ( "fmt" "net/http" - "net/url" "sort" "strings" @@ -71,11 +70,12 @@ func viewRun(opts *ViewOptions) error { return utils.OpenInBrowser(gistURL) } - u, err := url.Parse(opts.Selector) - if err == nil { - if strings.HasPrefix(u.Path, "/") { - gistID = u.Path[1:] + if strings.Contains(gistID, "/") { + id, err := shared.GistIDFromURL(gistID) + if err != nil { + return err } + gistID = id } client, err := opts.HttpClient() diff --git a/pkg/cmd/gist/view/view_test.go b/pkg/cmd/gist/view/view_test.go index 0ddc33181..0ddf55f49 100644 --- a/pkg/cmd/gist/view/view_test.go +++ b/pkg/cmd/gist/view/view_test.go @@ -90,11 +90,17 @@ func Test_viewRun(t *testing.T) { wantErr bool }{ { - name: "no such gist", + name: "no such gist", + opts: &ViewOptions{ + Selector: "1234", + }, wantErr: true, }, { name: "one file", + opts: &ViewOptions{ + Selector: "1234", + }, gist: &shared.Gist{ Files: map[string]*shared.GistFile{ "cicada.txt": { @@ -108,6 +114,7 @@ func Test_viewRun(t *testing.T) { { name: "filename selected", opts: &ViewOptions{ + Selector: "1234", Filename: "cicada.txt", }, gist: &shared.Gist{ @@ -126,6 +133,9 @@ func Test_viewRun(t *testing.T) { }, { name: "multiple files, no description", + opts: &ViewOptions{ + Selector: "1234", + }, gist: &shared.Gist{ Files: map[string]*shared.GistFile{ "cicada.txt": { @@ -142,6 +152,9 @@ func Test_viewRun(t *testing.T) { }, { name: "multiple files, description", + opts: &ViewOptions{ + Selector: "1234", + }, gist: &shared.Gist{ Description: "some files", Files: map[string]*shared.GistFile{ @@ -160,7 +173,8 @@ func Test_viewRun(t *testing.T) { { name: "raw", opts: &ViewOptions{ - Raw: true, + Selector: "1234", + Raw: true, }, gist: &shared.Gist{ Description: "some files", @@ -200,8 +214,6 @@ func Test_viewRun(t *testing.T) { io.SetStdoutTTY(true) tt.opts.IO = io - tt.opts.Selector = "1234" - t.Run(tt.name, func(t *testing.T) { err := viewRun(tt.opts) if tt.wantErr {