diff --git a/pkg/cmd/agent-task/shared/capi.go b/pkg/cmd/agent-task/shared/capi.go index f064eac2e..c61aa13a7 100644 --- a/pkg/cmd/agent-task/shared/capi.go +++ b/pkg/cmd/agent-task/shared/capi.go @@ -1,13 +1,19 @@ package shared import ( + "errors" + "fmt" "regexp" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" + prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" ) -var uuidRE = regexp.MustCompile(`^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`) +const uuidPattern = `[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}` + +var sessionIDRegexp = regexp.MustCompile(fmt.Sprintf("^%s$", uuidPattern)) +var agentSessionURLRegexp = regexp.MustCompile(fmt.Sprintf("^/agent-sessions/(%s)$", uuidPattern)) func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { return func() (capi.CapiClient, error) { @@ -27,5 +33,22 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { } func IsSessionID(s string) bool { - return uuidRE.MatchString(s) + return sessionIDRegexp.MatchString(s) +} + +// ParseSessionIDFromURL parses session ID from a pull request's agent session +// URL, which is of the form: +// +// `https://github.com/OWNER/REPO/pull/NUMBER/agent-sessions/SESSION-ID` +func ParseSessionIDFromURL(u string) (string, error) { + _, _, rest, err := prShared.ParseURL(u) + if err != nil { + return "", err + } + + match := agentSessionURLRegexp.FindStringSubmatch(rest) + if match == nil { + return "", errors.New("not a valid agent session URL") + } + return match[1], nil } diff --git a/pkg/cmd/agent-task/shared/capi_test.go b/pkg/cmd/agent-task/shared/capi_test.go index d6a106d1b..205d881c8 100644 --- a/pkg/cmd/agent-task/shared/capi_test.go +++ b/pkg/cmd/agent-task/shared/capi_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIsSession(t *testing.T) { @@ -18,3 +19,58 @@ func TestIsSession(t *testing.T) { assert.False(t, IsSessionID("000000000000000000000000000000000000")) assert.False(t, IsSessionID("00000000-0000-0000-0000-000000000000-extra")) } + +func TestParsePullRequestAgentSessionURL(t *testing.T) { + tests := []struct { + name string + url string + wantSessionID string + wantErr bool + }{ + { + name: "valid", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/e2fa49d2-f164-4a56-ab99-498090b8fcdf", + wantSessionID: "e2fa49d2-f164-4a56-ab99-498090b8fcdf", + }, + { + name: "invalid session id", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/fff", + wantErr: true, + }, + { + name: "no session id, trailing slash", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/", + wantErr: true, + }, + { + name: "no session id", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions", + wantErr: true, + }, + { + name: "invalid pr url", + url: "https://github.com/OWNER/REPO/issues/123", + wantErr: true, + }, + { + name: "empty", + url: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sessionID, err := ParseSessionIDFromURL(tt.url) + + if tt.wantErr { + require.Error(t, err) + assert.Zero(t, sessionID) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantSessionID, sessionID) + }) + } +} diff --git a/pkg/cmd/agent-task/view/view.go b/pkg/cmd/agent-task/view/view.go index 787157ffe..024101b0e 100644 --- a/pkg/cmd/agent-task/view/view.go +++ b/pkg/cmd/agent-task/view/view.go @@ -80,6 +80,8 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman opts.SelectorArg = args[0] if shared.IsSessionID(opts.SelectorArg) { opts.SessionID = opts.SelectorArg + } else if sessionID, err := shared.ParseSessionIDFromURL(opts.SelectorArg); err == nil { + opts.SessionID = sessionID } } @@ -149,7 +151,7 @@ func viewRun(opts *ViewOptions) error { session = sess } else { - var resourceID int64 + var prID int64 var prURL string if opts.SelectorArg != "" { @@ -169,14 +171,14 @@ func viewRun(opts *ViewOptions) error { return fmt.Errorf("agent tasks are not supported on this host: %s", hostname) } - resourceID, prURL, err = capiClient.GetPullRequestDatabaseID(ctx, hostname, repo.RepoOwner(), repo.RepoName(), num) + prID, prURL, err = capiClient.GetPullRequestDatabaseID(ctx, hostname, repo.RepoOwner(), repo.RepoName(), num) if err != nil { return fmt.Errorf("failed to fetch pull request: %w", err) } } } - if resourceID == 0 { + if prID == 0 { findOptions := prShared.FindOptions{ Selector: opts.SelectorArg, Fields: []string{"id", "url", "fullDatabaseId"}, @@ -196,7 +198,7 @@ func viewRun(opts *ViewOptions) error { return fmt.Errorf("failed to parse pull request: %w", err) } - resourceID = databaseID + prID = databaseID prURL = pr.URL } @@ -204,7 +206,7 @@ func viewRun(opts *ViewOptions) error { // matching sessions to avoid hitting the API too many times, but it's // technically possible for a PR to be associated with lots of sessions // (i.e. above our selected limit). - sessions, err := capiClient.ListSessionsByResourceID(ctx, "pull", resourceID, defaultLimit) + sessions, err := capiClient.ListSessionsByResourceID(ctx, "pull", prID, defaultLimit) if err != nil { return fmt.Errorf("failed to list sessions for pull request: %w", err) } diff --git a/pkg/cmd/agent-task/view/view_test.go b/pkg/cmd/agent-task/view/view_test.go index a4636c369..bdff45793 100644 --- a/pkg/cmd/agent-task/view/view_test.go +++ b/pkg/cmd/agent-task/view/view_test.go @@ -46,6 +46,15 @@ func TestNewCmdList(t *testing.T) { SessionID: "00000000-0000-0000-0000-000000000000", }, }, + { + name: "PR agent-session URL arg tty", + tty: true, + args: "https://github.com/OWNER/REPO/pull/101/agent-sessions/00000000-0000-0000-0000-000000000000", + wantOpts: ViewOptions{ + SelectorArg: "https://github.com/OWNER/REPO/pull/101/agent-sessions/00000000-0000-0000-0000-000000000000", + SessionID: "00000000-0000-0000-0000-000000000000", + }, + }, { name: "non-session ID arg tty", tty: true, diff --git a/pkg/cmd/pr/edit/edit.go b/pkg/cmd/pr/edit/edit.go index c01364d24..b196546cd 100644 --- a/pkg/cmd/pr/edit/edit.go +++ b/pkg/cmd/pr/edit/edit.go @@ -97,7 +97,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman // needs to know the API host. If the command is run outside of // a git repo, we cannot instantiate the detector unless we have // already parsed the URL. - if baseRepo, _, err := shared.ParseURL(opts.SelectorArg); err == nil { + if baseRepo, _, _, err := shared.ParseURL(opts.SelectorArg); err == nil { opts.BaseRepo = func() (ghrepo.Interface, error) { return baseRepo, nil } diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index cb8237d58..7d66d60f3 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -112,7 +112,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return nil, nil, errors.New("Find error: no fields specified") } - if repo, prNumber, err := ParseURL(opts.Selector); err == nil { + if repo, prNumber, _, err := ParseURL(opts.Selector); err == nil { f.prNumber = prNumber f.baseRefRepo = repo } @@ -300,32 +300,34 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return pr, f.baseRefRepo, g.Wait() } -var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) +var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)(.*$)`) -// ParseURL parses a pull request URL and returns the repository and pull -// request number. -func ParseURL(prURL string) (ghrepo.Interface, int, error) { +// ParseURL parses a pull request URL and returns the repository, pull request +// number, and any tailing path components. If there is no error, the returned +// repo is not nil and will have non-empty hostname. +func ParseURL(prURL string) (ghrepo.Interface, int, string, error) { if prURL == "" { - return nil, 0, fmt.Errorf("invalid URL: %q", prURL) + return nil, 0, "", fmt.Errorf("invalid URL: %q", prURL) } u, err := url.Parse(prURL) if err != nil { - return nil, 0, err + return nil, 0, "", err } if u.Scheme != "https" && u.Scheme != "http" { - return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme) + return nil, 0, "", fmt.Errorf("invalid scheme: %s", u.Scheme) } m := pullURLRE.FindStringSubmatch(u.Path) if m == nil { - return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL) + return nil, 0, "", fmt.Errorf("not a pull request URL: %s", prURL) } repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) prNumber, _ := strconv.Atoi(m[3]) - return repo, prNumber, nil + tail := m[4] + return repo, prNumber, tail, nil } var fullReferenceRE = regexp.MustCompile(`^(?:([^/]+)/([^/]+))#(\d+)$`) diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 5e33ee876..470709480 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -21,6 +21,7 @@ func TestParseURL(t *testing.T) { arg string wantRepo ghrepo.Interface wantNum int + wantRest string wantErr string }{ { @@ -35,15 +36,46 @@ func TestParseURL(t *testing.T) { wantRepo: ghrepo.NewWithHost("owner", "repo", "example.com"), wantNum: 123, }, + { + name: "valid HTTP URL with rest", + arg: "http://example.com/owner/repo/pull/123/foo/bar", + wantRepo: ghrepo.NewWithHost("owner", "repo", "example.com"), + wantNum: 123, + wantRest: "/foo/bar", + }, + { + name: "valid HTTP URL with .patch as rest", + arg: "http://example.com/owner/repo/pull/123.patch", + wantRepo: ghrepo.NewWithHost("owner", "repo", "example.com"), + wantNum: 123, + wantRest: ".patch", + }, + { + name: "valid HTTP URL with a trailing slash", + arg: "http://example.com/owner/repo/pull/123/", + wantRepo: ghrepo.NewWithHost("owner", "repo", "example.com"), + wantNum: 123, + wantRest: "/", + }, { name: "empty URL", wantErr: "invalid URL: \"\"", }, + { + name: "no scheme", + arg: "github.com/owner/repo/pull/123", + wantErr: "invalid scheme: ", + }, { name: "invalid scheme", arg: "ftp://github.com/owner/repo/pull/123", wantErr: "invalid scheme: ftp", }, + { + name: "no hostname", + arg: "/owner/repo/pull/123", + wantErr: "invalid scheme: ", + }, { name: "incorrect path", arg: "https://github.com/owner/repo/issues/123", @@ -63,7 +95,7 @@ func TestParseURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - repo, num, err := ParseURL(tt.arg) + repo, num, rest, err := ParseURL(tt.arg) if tt.wantErr != "" { require.Error(t, err) @@ -73,6 +105,7 @@ func TestParseURL(t *testing.T) { require.NoError(t, err) require.Equal(t, tt.wantNum, num) + require.Equal(t, tt.wantRest, rest) require.NotNil(t, repo) require.True(t, ghrepo.IsSame(tt.wantRepo, repo)) })