diff --git a/pkg/cmd/agent-task/shared/capi.go b/pkg/cmd/agent-task/shared/capi.go index f064eac2e..36f206e34 100644 --- a/pkg/cmd/agent-task/shared/capi.go +++ b/pkg/cmd/agent-task/shared/capi.go @@ -1,13 +1,16 @@ package shared import ( + "errors" "regexp" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" + prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" ) var uuidRE = regexp.MustCompile(`^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`) +var agentSessionsPathRE = regexp.MustCompile(`^/agent-sessions/([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})$`) func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { return func() (capi.CapiClient, error) { @@ -29,3 +32,20 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { func IsSessionID(s string) bool { return uuidRE.MatchString(s) } + +// ParsePullRequestAgentSessionURL parses session ID from a pull request's agent +// session URL, which is of the form: +// +// https://github.com/OWNER/REPO/pull/NUMBER/agent-sessions/SESSION-ID +func ParsePullRequestAgentSessionURL(u string) (string, error) { + _, _, rest, err := prShared.ParseURL(u) + if err != nil { + return "", err + } + + match := agentSessionsPathRE.FindStringSubmatch(rest) + if match == nil { + return "", errors.New("not a valid agent session URL") + } + return match[1], nil +} diff --git a/pkg/cmd/agent-task/shared/capi_test.go b/pkg/cmd/agent-task/shared/capi_test.go index d6a106d1b..d29f50624 100644 --- a/pkg/cmd/agent-task/shared/capi_test.go +++ b/pkg/cmd/agent-task/shared/capi_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIsSession(t *testing.T) { @@ -18,3 +19,58 @@ func TestIsSession(t *testing.T) { assert.False(t, IsSessionID("000000000000000000000000000000000000")) assert.False(t, IsSessionID("00000000-0000-0000-0000-000000000000-extra")) } + +func TestParsePullRequestAgentSessionURL(t *testing.T) { + tests := []struct { + name string + url string + wantSessionID string + wantErr bool + }{ + { + name: "valid", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/e2fa49d2-f164-4a56-ab99-498090b8fcdf", + wantSessionID: "e2fa49d2-f164-4a56-ab99-498090b8fcdf", + }, + { + name: "invalid session id", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/fff", + wantErr: true, + }, + { + name: "no session id, trailing slash", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/", + wantErr: true, + }, + { + name: "no session id", + url: "https://github.com/OWNER/REPO/pull/123/agent-sessions", + wantErr: true, + }, + { + name: "invalid pr url", + url: "https://github.com/OWNER/REPO/issues/123", + wantErr: true, + }, + { + name: "empty", + url: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sessionID, err := ParsePullRequestAgentSessionURL(tt.url) + + if tt.wantErr { + require.Error(t, err) + assert.Zero(t, sessionID) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantSessionID, sessionID) + }) + } +}